# Evaluating your SAE

Code based off Rob Graham's ([themachinefan](https://github.com/themachinefan)) SAE evaluation code.

In [1]:
import torch
import torchvision

import plotly.express as px

from tqdm import tqdm

import einops

import numpy as np
import os

import requests


# Setup

In [3]:
from dataclasses import dataclass
from vit_prisma.sae.config import VisionModelSAERunnerConfig

@dataclass
class EvalConfig(VisionModelSAERunnerConfig):
    sae_path: str = '/network/scratch/s/sonia.joseph/sae_checkpoints/tinyclip_40M_mlp_out/6bdc5c08-wkcn-TinyCLIP-ViT-40M-32-Text-19M-LAION400M-expansion-16/n_images_260014.pt'
    model_name: str = "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M"
    model_type: str =  "clip"
    patch_size: str = 32

    dataset_path = "/network/scratch/s/sonia.joseph/datasets/kaggle_datasets"
    dataset_train_path: str = "/network/scratch/s/sonia.joseph/datasets/kaggle_datasets/ILSVRC/Data/CLS-LOC/train"
    dataset_val_path: str = "/network/scratch/s/sonia.joseph/datasets/kaggle_datasets/ILSVRC/Data/CLS-LOC/val"

    verbose: bool = True

    device: bool = 'cuda'

    eval_max: int = 50_000 # 50_000
    batch_size: int = 32

    # make the max image output folder a subfolder of the sae path


    @property
    def max_image_output_folder(self) -> str:
        # Get the base directory of sae_checkpoints
        sae_base_dir = os.path.dirname(os.path.dirname(self.sae_path))
        
        # Get the name of the original SAE checkpoint folder
        sae_folder_name = os.path.basename(os.path.dirname(self.sae_path))
        
        # Create a new folder path in sae_checkpoints/images with the original name
        output_folder = os.path.join(sae_base_dir, 'max_images', sae_folder_name)
        output_folder = os.path.join(output_folder, f"layer_{self.hook_point_layer}") # Add layer number

        
        # Ensure the directory exists
        os.makedirs(output_folder, exist_ok=True)
        
        return output_folder

cfg = EvalConfig()

n_tokens_per_buffer (millions): 0.032
Lower bound: n_contexts_per_buffer (millions): 0.00064
Total training steps: 15869
Total training images: 1300000
Total wandb updates: 158
Expansion factor: 16
n_tokens_per_feature_sampling_window (millions): 61.44
n_tokens_per_dead_feature_window (millions): 1024.0
Using Ghost Grads.
We will reset the sparsity calculation 52 times.
Number tokens in sparsity calculation window: 1.23e+06
Gradient clipping with max_norm=1.0
Using SAE initialization method: encoder_transpose_decoder


In [4]:
torch.set_grad_enabled(False)




<torch.autograd.grad_mode.set_grad_enabled at 0x7f04dd492920>

## Load model

In [5]:
from vit_prisma.models.base_vit import HookedViT

model_name = "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M"
model = HookedViT.from_pretrained(model_name, is_timm=False, is_clip=True).to(cfg.device)
 

{'n_layers': 12, 'd_model': 512, 'd_head': 64, 'model_name': '', 'n_heads': 8, 'd_mlp': 2048, 'activation_name': 'gelu', 'eps': 1e-05, 'original_architecture': 'vit_clip_vision_encoder', 'initializer_range': 0.02, 'n_channels': 3, 'patch_size': 32, 'image_size': 224, 'n_classes': 512, 'n_params': None, 'layer_norm_pre': True, 'return_type': 'class_logits'}
LayerNorm folded.
Centered weights writing to residual stream
Loaded pretrained model wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M into HookedTransformer


## Load datasets

In [6]:
# load dataset
from vit_prisma.utils.data_utils.imagenet_utils import setup_imagenet_paths
from vit_prisma.dataloaders.imagenet_dataset import get_imagenet_transforms_clip, ImageNetValidationDataset


if cfg.model_type == 'clip':
    data_transforms = get_imagenet_transforms_clip(cfg.model_name)
else:
    raise ValueError("Invalid model type")
imagenet_paths = setup_imagenet_paths(cfg.dataset_path)
train_data = torchvision.datasets.ImageFolder(cfg.dataset_train_path, transform=data_transforms)
val_data = ImageNetValidationDataset(cfg.dataset_val_path, 
                                imagenet_paths['label_strings'], 
                                imagenet_paths['val_labels'], 
                                data_transforms,
                                return_index=True,
)
val_data_visualize = ImageNetValidationDataset(cfg.dataset_val_path, 
                                imagenet_paths['label_strings'], 
                                imagenet_paths['val_labels'],
                                torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),]), return_index=True)

print(f"Validation data length: {len(val_data)}") if cfg.verbose else None


In [12]:
from vit_prisma.sae.training.activations_store import VisionActivationsStore
# import dataloader
from torch.utils.data import DataLoader

# activations_loader = VisionActivationsStore(cfg, model, train_data, eval_dataset=val_data)
val_dataloader = DataLoader(val_data, batch_size=cfg.batch_size, shuffle=False, num_workers=4)


## Load pretrained SAE to evaluate

In [None]:
from vit_prisma.sae.sae import SparseAutoencoder
sparse_autoencoder = SparseAutoencoder(cfg).load_from_pretrained(cfg.sae_path)
sparse_autoencoder.to(cfg.device)
sparse_autoencoder.eval()  # prevents error if we're expecting a dead neuron mask for who 


SparseAutoencoder(
  (hook_sae_in): HookPoint()
  (hook_hidden_pre): HookPoint()
  (hook_hidden_post): HookPoint()
  (hook_sae_out): HookPoint()
)

# Evaluate sparsity

## Average L0 Test

Calculate L0 on just a single batch as a sanity check. We'll do more intensive calculations later.

In [None]:
import plotly.express as px
import plotly.graph_objects as go
import torch
import numpy as np
from tqdm import tqdm
import pandas as pd

# Initialize variables to store running statistics
sum_l0 = None
count = 0
all_l0_values = []

# Set the number of samples to keep for the box plot
n_samples = 10000

# Iterate over the entire validation dataset
with torch.no_grad():
    for batch_tokens, labels, indices in tqdm(val_dataloader, desc="Processing batches"):
        batch_tokens = batch_tokens.to(cfg.device)
        _, cache = model.run_with_cache(batch_tokens, names_filter = sparse_autoencoder.cfg.hook_point)
        sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sparse_autoencoder(
            cache[sparse_autoencoder.cfg.hook_point].to(cfg.device)
        )
        del cache

        # Calculate L0 per patch for this batch
        l0 = (feature_acts > 0).float().sum(-1).detach().cpu().numpy()
        
        # Update running statistics
        if sum_l0 is None:
            sum_l0 = l0.sum(axis=0)
        else:
            sum_l0 += l0.sum(axis=0)
        count += l0.shape[0]
        
        # Store a random sample of L0 values for plots
        if len(all_l0_values) < n_samples:
            n_to_add = min(n_samples - len(all_l0_values), l0.shape[0])
            all_l0_values.extend(l0[:n_to_add])

# Calculate the final mean
mean_l0 = sum_l0 / count

print("Mean L0 shape:", mean_l0.shape)
print("Overall average L0:", mean_l0.mean())

# Create patch numbers
patch_numbers = np.arange(1, mean_l0.shape[0] + 1)

# Create an interactive scatter plot
fig = go.Figure()

# Add scatter plot
fig.add_trace(go.Scatter(
    x=patch_numbers,
    y=mean_l0,
    mode='markers',
    name='Average L0',
    marker=dict(color='#FFC0CB', size=8)
))

# Update layout
fig.update_layout(
    title='Average L0 by Patch Number',
    xaxis_title='Patch Number',
    yaxis_title='Average L0',
    plot_bgcolor='white',
    paper_bgcolor='white',
    font=dict(color='black'),
    title_font=dict(size=24, color='black'),
    xaxis=dict(
        tickmode='linear',
        dtick=max(1, len(patch_numbers) // 20),
        gridcolor='lightgrey',
        zerolinecolor='black'
    ),
    yaxis=dict(
        gridcolor='lightgrey',
        zerolinecolor='black'
    ),
    showlegend=False
)

fig.show()

# Create a heatmap of L0 values
heatmap_data = np.array(all_l0_values)
heatmap_fig = px.imshow(heatmap_data,
                        labels=dict(x="Patch Number", y="Sample", color="L0 Value"),
                        title="Heatmap of L0 Values")
heatmap_fig.update_layout(
    plot_bgcolor='white',
    paper_bgcolor='white',
    font=dict(color='black'),
    title_font=dict(size=24, color='black')
)
heatmap_fig.show()

# Create a violin plot
df = pd.DataFrame(all_l0_values, columns=[f'Patch {i}' for i in patch_numbers])
df_melted = df.melt(var_name='Patch', value_name='L0')

violin_fig = px.violin(df_melted, x='Patch', y='L0', box=True, points="all",
                       labels={'Patch': 'Patch Number', 'L0': 'L0 Value'},
                       title='Distribution of L0 Values by Patch')
violin_fig.update_traces(meanline_visible=True)
violin_fig.update_layout(
    plot_bgcolor='white',
    paper_bgcolor='white',
    font=dict(color='black'),
    title_font=dict(size=24, color='black'),
    xaxis=dict(gridcolor='lightgrey', zerolinecolor='black'),
    yaxis=dict(gridcolor='lightgrey', zerolinecolor='black')
)
violin_fig.show()

Processing batches:   2%|▏         | 24/1563 [00:02<01:31, 16.87it/s]

In [36]:
avg_l0_per_patch = np.mean(all_l0_per_patch, axis=0)
avg_l0_per_patch


array([77.805145, 50.25    , 55.695312, 61.925552, 61.654873, 61.181984,
       56.148438, 49.869026, 55.212315, 59.26011 , 63.0841  , 62.843292,
       63.030792, 60.63419 , 54.27114 , 61.11581 , 60.1034  , 62.76103 ,
       61.982536, 62.74816 , 61.641083, 57.955883, 59.20542 , 60.409927,
       60.19899 , 61.856617, 61.483456, 60.130974, 55.88741 , 55.595127,
       59.142464, 60.311123, 60.919117, 61.05101 , 58.80285 , 53.325367,
       54.06434 , 55.607998, 56.627758, 57.193016, 57.949448, 56.22748 ,
       53.373623, 49.465992, 55.170036, 56.524815, 55.20542 , 57.29044 ,
       55.710938, 49.25965 ], dtype=float32)

In [47]:
import plotly.express as px
import plotly.graph_objects as go
import torch
import numpy as np
from tqdm import tqdm
import pandas as pd

# Assuming all_l0_per_patch is already a numpy array after processing the validation dataset

# Calculate the average L0 per patch across all batches
avg_l0_per_patch = np.mean(all_l0_per_patch, axis=0)

# Calculate the standard error of the mean
sem_l0_per_patch = np.std(all_l0_per_patch, axis=0) / np.sqrt(all_l0_per_patch.shape[0])

print("Average L0 shape:", avg_l0_per_patch.shape)
print("Overall average L0 per patch:", avg_l0_per_patch.mean())

# Create patch numbers
patch_numbers = np.arange(1, len(avg_l0_per_patch) + 1)

# Create a DataFrame for the plot
df = pd.DataFrame({
    'Patch Number': patch_numbers,
    'Average L0': avg_l0_per_patch,
    'SEM': sem_l0_per_patch
})

# Create an interactive scatter plot with error bars
fig = go.Figure()

# Add scatter plot
fig.add_trace(go.Scatter(
    x=df['Patch Number'],
    y=df['Average L0'],
    mode='markers',
    name='Average L0',
    marker=dict(color='#FFC0CB', size=8),
    error_y=dict(
        type='data',
        array=df['SEM'],
        visible=True,
        color='black'
    )
))

# Update layout
fig.update_layout(
    title='Average L0 by Patch Number with Error Bars',
    xaxis_title='Patch Number',
    yaxis_title='Average L0',
    plot_bgcolor='white',
    paper_bgcolor='white',
    font=dict(color='black'),
    title_font=dict(size=24, color='black'),
    xaxis=dict(
        tickmode='linear',
        dtick=max(1, len(patch_numbers) // 20),
        gridcolor='lightgrey',
        zerolinecolor='black'
    ),
    yaxis=dict(
        gridcolor='lightgrey',
        zerolinecolor='black'
    ),
    showlegend=False
)

fig.show()

# Create a histogram of average L0 values per patch
hist_fig = px.histogram(avg_l0_per_patch, title="Distribution of Average L0 values per patch")
hist_fig.update_traces(marker=dict(color='#FFC0CB', line=dict(color='black', width=1)))
hist_fig.update_layout(
    plot_bgcolor='white',
    paper_bgcolor='white',
    font=dict(color='black'),
    title=dict(font=dict(size=24, color='black')),
    xaxis=dict(gridcolor='lightgrey', zerolinecolor='black'),
    yaxis=dict(gridcolor='lightgrey', zerolinecolor='black')
)
hist_fig.show()

# Create a box plot to show the distribution of L0 values for each patch
all_l0_per_patch_flat = all_l0_per_patch.T
box_fig = px.box(data_frame=pd.DataFrame(all_l0_per_patch_flat, columns=[f'Patch {i+1}' for i in range(all_l0_per_patch_flat.shape[1])]))
box_fig.update_traces(marker=dict(color='#FFC0CB'), line=dict(color='black'))
box_fig.update_layout(
    title='Distribution of L0 values for each patch',
    xaxis_title='Patch Number',
    yaxis_title='L0',
    xaxis=dict(tickmode='linear', dtick=max(1, all_l0_per_patch_flat.shape[0] // 20),
               gridcolor='lightgrey', zerolinecolor='black'),
    yaxis=dict(gridcolor='lightgrey', zerolinecolor='black'),
    plot_bgcolor='white',
    paper_bgcolor='white',
    font=dict(color='black'),
    title_font=dict(size=24, color='black')
)
box_fig.show()

Average L0 shape: (50,)
Overall average L0 per patch: 58.03682


In [50]:
all_l0_per_patch_flat.shape

(50, 48)

In [49]:
# Create a box plot to show the distribution of L0 values for each patch
all_l0_per_patch_flat = all_l0_per_patch.T
box_fig = px.box(data_frame=pd.DataFrame(all_l0_per_patch_flat, columns=[f'Patch {i+1}' for i in range(all_l0_per_patch_flat.shape[1])]))
box_fig.update_traces(marker=dict(color='#FFC0CB'), line=dict(color='black'))
box_fig.update_layout(
    title='Distribution of L0 values for each patch',
    xaxis_title='Patch Number',
    yaxis_title='L0',
    xaxis=dict(
        tickmode='linear',
        tickvals=list(range(1, all_l0_per_patch_flat.shape[1] + 1)),
        ticktext=[f'Patch {i+1}' for i in range(all_l0_per_patch_flat.shape[1])],
        gridcolor='lightgrey',
        zerolinecolor='black'
    ),
    yaxis=dict(gridcolor='lightgrey', zerolinecolor='black'),
    plot_bgcolor='white',
    paper_bgcolor='white',
    font=dict(color='black'),
    title_font=dict(size=24, color='black')
)
box_fig.show()

In [45]:
import plotly.express as px
import plotly.graph_objects as go
import torch
import numpy as np
from tqdm import tqdm
import pandas as pd

# Assuming all_l0_per_patch is already a numpy array after processing the validation dataset

# Calculate the average L0 per patch across all batches
avg_l0_per_patch = np.mean(all_l0_per_patch, axis=0)

# Calculate the standard error of the mean
sem_l0_per_patch = np.std(all_l0_per_patch, axis=0) / np.sqrt(all_l0_per_patch.shape[0])

print("Average L0 shape:", avg_l0_per_patch.shape)
print("Overall average L0 per patch:", avg_l0_per_patch.mean())

# Create patch numbers
patch_numbers = np.arange(1, len(avg_l0_per_patch) + 1)

# Create a DataFrame for the plot
df = pd.DataFrame({
    'Patch Number': patch_numbers,
    'Average L0': avg_l0_per_patch,
    'SEM': sem_l0_per_patch
})

# Create an interactive scatter plot with error bars
fig = go.Figure()

# Add scatter plot
fig.add_trace(go.Scatter(
    x=df['Patch Number'],
    y=df['Average L0'],
    mode='markers',
    name='Average L0',
    marker=dict(color='#FF6B9D', size=8),
    error_y=dict(
        type='data',
        array=df['SEM'],
        visible=True,
        color='#9C4A88'
    )
))

# Update layout
fig.update_layout(
    title='Average L0 by Patch Number with Error Bars',
    xaxis_title='Patch Number',
    yaxis_title='Average L0',
    plot_bgcolor='#F5F5F5',
    paper_bgcolor='#F5F5F5',
    font=dict(color='#1A1A1A'),
    title_font=dict(size=24, color='#1A1A1A'),
    xaxis=dict(
        tickmode='linear',
        dtick=max(1, len(patch_numbers) // 20),
        gridcolor='#E0E0E0',
        zerolinecolor='#1A1A1A'
    ),
    yaxis=dict(
        gridcolor='#E0E0E0',
        zerolinecolor='#1A1A1A'
    ),
    showlegend=False
)

fig.show()

# Create a histogram of average L0 values per patch
hist_fig = px.histogram(avg_l0_per_patch, title="Distribution of Average L0 values per patch")
hist_fig.update_traces(marker=dict(color='#9C4A88', line=dict(color='#1A1A1A', width=1)))
hist_fig.update_layout(
    plot_bgcolor='#F5F5F5',
    paper_bgcolor='#F5F5F5',
    font=dict(color='#1A1A1A'),
    title=dict(font=dict(size=24, color='#1A1A1A')),
    xaxis=dict(gridcolor='#E0E0E0', zerolinecolor='#1A1A1A'),
    yaxis=dict(gridcolor='#E0E0E0', zerolinecolor='#1A1A1A')
)
hist_fig.show()

# Create a box plot to show the distribution of L0 values for each patch
all_l0_per_patch_flat = all_l0_per_patch.T
box_fig = px.box(data_frame=pd.DataFrame(all_l0_per_patch_flat, columns=[f'Patch {i+1}' for i in range(all_l0_per_patch_flat.shape[1])]))
box_fig.update_traces(marker=dict(color='#9C4A88'), line=dict(color='#1A1A1A'))
box_fig.update_layout(
    title='Distribution of L0 values for each patch',
    xaxis_title='Patch Number',
    yaxis_title='L0',
    xaxis=dict(tickmode='linear', dtick=max(1, all_l0_per_patch_flat.shape[0] // 20),
               gridcolor='#E0E0E0', zerolinecolor='#1A1A1A'),
    yaxis=dict(gridcolor='#E0E0E0', zerolinecolor='#1A1A1A'),
    plot_bgcolor='#F5F5F5',
    paper_bgcolor='#F5F5F5',
    font=dict(color='#1A1A1A'),
    title_font=dict(size=24, color='#1A1A1A')
)
box_fig.show()

Average L0 shape: (50,)
Overall average L0 per patch: 58.03682


## Get feature probability

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

@torch.no_grad()
def get_feature_probability(images, model, sparse_autoencoder):
    _, cache = model.run_with_cache(images, names_filter=[sparse_autoencoder.cfg.hook_point])
    sae_out, feature_acts, loss, mse_loss, l1_loss, _ = sparse_autoencoder(
        cache[sparse_autoencoder.cfg.hook_point]
    )

    # Flatten first two dimensions (batch, position) to get a 2D tensor of activations
    return (feature_acts.abs() > 0).float().flatten(0, 1)

def process_dataset(val_dataloader, model, sparse_autoencoder, cfg):
    total_acts = None
    total_tokens = 0
    
    for idx, batch in tqdm(enumerate(val_dataloader), total=cfg.eval_max//cfg.batch_size):
        images = batch[0]

        images = images.to(cfg.device)
        sae_activations = get_feature_probability(images, model, sparse_autoencoder)
        
        if total_acts is None:
            total_acts = sae_activations.sum(0)
        else:
            total_acts += sae_activations.sum(0)
        
        total_tokens += sae_activations.shape[0]
        
        # if total_tokens >= cfg.eval_max:
        #     break
    
    return total_acts, total_tokens

def calculate_log_frequencies(total_acts, total_tokens):
    feature_probs = total_acts / total_tokens
    log_feature_probs = torch.log10(feature_probs)
    return log_feature_probs.cpu().numpy()

# def plot_histogram(log_frequencies, num_bins=100): # Note: black edged histograms look great!
#     plt.figure(figsize=(12, 6))
#     plt.hist(log_frequencies, bins=num_bins, edgecolor='black')
#     plt.xlabel('Log10 Feature Frequency')
#     plt.ylabel('Count')
#     plt.title('Log Feature Density Histogram')
#     plt.grid(True, alpha=0.3)
#     plt.show()


# Main execution
total_acts, total_tokens = process_dataset(val_dataloader, model, sparse_autoencoder, cfg)

log_frequencies = calculate_log_frequencies(total_acts, total_tokens)

print(f"Total tokens processed: {total_tokens}")
print(f"Average activations per token: {total_acts.sum().item() / total_tokens:.4f}")


  0%|          | 0/1562 [00:00<?, ?it/s]

  2%|▏         | 31/1562 [00:03<02:50,  8.98it/s]


Total tokens processed: 51200
Average activations per token: 22.2705


In [None]:
def plot_histogram_px(log_frequencies, num_bins=100):
    fig = px.histogram(
        x=log_frequencies,
        nbins=num_bins,
        labels={'x': 'Log10 Feature Frequency', 'y': 'Count'},
        title='Log Feature Density Histogram',
        opacity=0.7,
    )
    fig.update_layout(
        bargap=0.1,
        xaxis_title='Log10 Feature Frequency',
        yaxis_title='Count',
        plot_bgcolor='rgba(240, 240, 240, 0.8)',  # Light gray background
        xaxis=dict(showgrid=True, gridwidth=1, gridcolor='White'),
        yaxis=dict(showgrid=True, gridwidth=1, gridcolor='White'),
    )
    fig.show()

    
plot_histogram_px(log_frequencies, num_bins=240)

In [None]:
log_freq = torch.Tensor(log_frequencies)

# minimum and maximum log_freq
min_log_freq = log_freq.min().item()
max_log_freq = log_freq.max().item()

print(f"Minimum log frequency: {min_log_freq:.4f}")
print(f"Maximum log frequency: {max_log_freq:.4f}")

Minimum log frequency: -inf
Maximum log frequency: -0.5430


In [None]:

# helper functions
update_layout_set = {"xaxis_range", "yaxis_range", "hovermode", "xaxis_title", "yaxis_title", "colorbar", "colorscale", "coloraxis", "title_x", "bargap", "bargroupgap", "xaxis_tickformat", "yaxis_tickformat", "title_y", "legend_title_text", "xaxis_showgrid", "xaxis_gridwidth", "xaxis_gridcolor", "yaxis_showgrid", "yaxis_gridwidth", "yaxis_gridcolor", "showlegend", "xaxis_tickmode", "yaxis_tickmode", "margin", "xaxis_visible", "yaxis_visible", "bargap", "bargroupgap", "coloraxis_showscale"}
def to_numpy(tensor):
    """
    Helper function to convert a tensor to a numpy array. Also works on lists, tuples, and numpy arrays.
    """
    if isinstance(tensor, np.ndarray):
        return tensor
    elif isinstance(tensor, (list, tuple)):
        array = np.array(tensor)
        return array
    elif isinstance(tensor, (torch.Tensor, torch.nn.parameter.Parameter)):
        return tensor.detach().cpu().numpy()
    elif isinstance(tensor, (int, float, bool, str)):
        return np.array(tensor)
    else:
        raise ValueError(f"Input to to_numpy has invalid type: {type(tensor)}")

def hist(tensor, save_name, show=True, renderer=None, **kwargs):
    '''
    '''
    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
    if "bargap" not in kwargs_post:
        kwargs_post["bargap"] = 0.1
    if "margin" in kwargs_post and isinstance(kwargs_post["margin"], int):
        kwargs_post["margin"] = dict.fromkeys(list("tblr"), kwargs_post["margin"])

    histogram_fig = px.histogram(x=to_numpy(tensor), **kwargs_pre)
    histogram_fig.update_layout(**kwargs_post)

    # Save the figure as a PNG file
    # histogram_fig.write_image(os.path.join(OUTPUT_FOLDER, f"{save_name}.png"))
    if show:
        px.histogram(x=to_numpy(tensor), **kwargs_pre).update_layout(**kwargs_post).show(renderer)



In [None]:
def visualize_sparsities(log_freq, conditions, condition_texts, name):
    # Visualise sparsities for each instance
    hist(
        log_freq,
        f"{name}_frequency_histogram",
        show=True,
        title=f"{name} Log Frequency of Features",
        labels={"x": "log<sub>10</sub>(freq)"},
        histnorm="percent",
        template="ggplot2"
    )

    #TODO these conditions need to be tuned to distribution of your data!

# Define intervals based on the specified ranges
intervals = [
    (-8, -6),
    (-6, -5),
    (-5, -4),
    (-4, -3),
    (-3, -2),
    (-2, -1),
    (-float('inf'), -8),  # This covers the [-8, -4] range and below
    (-1, float('inf'))    # This covers everything above -1
]

conditions = [torch.logical_and(log_freq >= lower, log_freq < upper) for lower, upper in intervals]
condition_texts = [
    f"TOTAL_logfreq_[{lower},{upper}]" for lower, upper in intervals
]

# Replace infinity with appropriate text for readability
condition_texts[-2] = condition_texts[-2].replace('-inf', '-∞')
condition_texts[-1] = condition_texts[-1].replace('inf', '∞')

visualize_sparsities(log_freq, conditions, condition_texts, "TOTAL")

# Reconstruction loss

In [None]:

def get_reconstruction_loss(
    images,
    model,
    autoencoder,
):
    '''
    Returns the reconstruction loss of each autoencoder instance on the given batch of tokens (i.e.
    the L2 loss between the activations and the autoencoder's reconstructions, averaged over all tokens).
    '''

    logits, cache = model.run_with_cache(images, names_filter=[sparse_autoencoder.cfg.hook_point])
    sae_out, feature_acts, loss, mse_loss, l1_loss, mse_loss_ghost_resid = sparse_autoencoder(
        cache[sparse_autoencoder.cfg.hook_point]
    )

    # Print out the avg L2 norm of activations
    print("Avg L2 norm of acts: ", cache[sparse_autoencoder.cfg.hook_point].pow(2).mean().item())

    # Print out the cosine similarity between original neuron activations & reconstructions (averaged over neurons)
    print("Avg cos sim of neuron reconstructions: ", torch.cosine_similarity(einops.rearrange( cache[sparse_autoencoder.cfg.hook_point], "batch seq d_mlp -> (batch seq) d_mlp"),
                                                                              einops.rearrange( sae_out, "batch seq d_mlp -> (batch seq) d_mlp"),
                                                                                dim=0).mean(-1).tolist())
    print("l1", l1_loss.sum().item())
    return mse_loss.item()

this_max = 4
count = 0
print(sparse_autoencoder.cfg.hook_point)
for batch_idx, (total_images, total_labels, total_indices) in enumerate(val_dataloader):
        total_images = total_images.to(cfg.device)
        reconstruction_loss = get_reconstruction_loss(total_images, model, sparse_autoencoder)
        print("mse", reconstruction_loss)
        if batch_idx >= this_max:
            break

blocks.9.hook_mlp_out
Avg L2 norm of acts:  0.031900838017463684
Avg cos sim of neuron reconstructions:  0.96854567527771
l1 0.05544371157884598
mse 0.0005733075668103993
Avg L2 norm of acts:  0.03172605112195015
Avg cos sim of neuron reconstructions:  0.9678786396980286
l1 0.05587046593427658
mse 0.0005859808879904449
Avg L2 norm of acts:  0.03149449825286865
Avg cos sim of neuron reconstructions:  0.9668527841567993
l1 0.056405287235975266
mse 0.0005890288157388568
Avg L2 norm of acts:  0.030216442421078682
Avg cos sim of neuron reconstructions:  0.9674074649810791
l1 0.053135380148887634
mse 0.0005861470708623528
Avg L2 norm of acts:  0.031549710780382156
Avg cos sim of neuron reconstructions:  0.9655236005783081
l1 0.056917689740657806
mse 0.0006197023903951049


# Substitution Loss

In [None]:
import torch
from transformers import CLIPModel, CLIPProcessor


device = 'cuda'

def get_text_embeddings(model_name, original_text, batch_size=32):
    vanilla_model = CLIPModel.from_pretrained(model_name)
    
    processor = CLIPProcessor.from_pretrained(model_name, do_rescale=False)

    # Split the text into batches
    text_batches = [original_text[i:i+batch_size] for i in range(0, len(original_text), batch_size)]

    all_embeddings = []

    for batch in text_batches:
        inputs = processor(text=batch, return_tensors='pt', padding=True, truncation=True, max_length=77)
        # inputs = {k: v.to(cfg.device) for k, v in inputs.items()}

        with torch.no_grad():
            text_embeddings = vanilla_model.get_text_features(**inputs)

        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
        all_embeddings.append(text_embeddings)

    # Concatenate all batches
    final_embeddings = torch.cat(all_embeddings, dim=0)

    return final_embeddings



def get_similarity(image_features, text_features, k=5, device='cuda'):
  image_features = image_features.to(device)
  text_features = text_features.to(device)

  softmax_values = (image_features @ text_features.T).softmax(dim=-1)
  top_k_values, top_k_indices = torch.topk(softmax_values, k, dim=-1)
  return softmax_values, top_k_indices



In [None]:
def get_text_labels(name='wordbank'):
    """
Loads the library of logit labels from a GitHub URL.

Returns:
list: A list of string labels.
"""
    if name == 'wordbank':
        url = "https://raw.githubusercontent.com/yossigandelsman/clip_text_span/main/text_descriptions/image_descriptions_general.txt"
        try:
            # Fetch the content from the URL
            response = requests.get(url)
            response.raise_for_status()  # Raise an exception for bad status codes
            
            # Split the content into lines and strip whitespace
            all_labels = [line.strip() for line in response.text.splitlines()]
            
            print(f"Number of labels loaded: {len(all_labels)}")
            print(f"First 5 labels: {all_labels[:5]}")
            return all_labels
        
        except requests.RequestException as e:
            print(f"An error occurred while fetching the labels: {e}")
            return []
    elif name == 'imagenet':
        from vit_prisma.dataloaders.imagenet_dataset import get_imagenet_text_labels
        return get_imagenet_text_labels()
    else:
        raise ValueError(f"Invalid label set name: {name}")

In [None]:
import os # For debugging purposes
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'

In [None]:
# Import f
import torch.nn.functional as F
# import partial
from functools import partial
# import Any
from typing import Any, List

def zero_ablate_hook(activations: torch.Tensor, hook: Any):
    activations = torch.zeros_like(activations)
    return activations

def print_model_device(model):
    print(f"Model is on device: {next(model.parameters()).device}")

def safe_custom_cross_entropy_loss(logits, targets):
    """
    Compute cross entropy loss manually with additional safety checks.
    """
    # Move tensors to CPU for safer computation
    logits = logits.detach().cpu()
    targets = targets.detach().cpu()
    
    batch_size, num_classes = logits.shape
    
    print(f"Logits shape: {logits.shape}")
    print(f"Targets shape: {targets.shape}")
    print(f"Targets min: {targets.min()}, max: {targets.max()}")
    
    # Check for NaN or Inf values
    if torch.isnan(logits).any() or torch.isinf(logits).any():
        print("Warning: NaN or Inf values detected in logits")
        return float('nan')
    
    # Check if targets are within the valid range
    if targets.min() < 0 or targets.max() >= num_classes:
        print(f"Error: Target labels out of range. Should be in [0, {num_classes-1}]")
        return float('nan')
    
    # Numerical stability - subtract max logit
    max_logits, _ = torch.max(logits, dim=1, keepdim=True)
    logits_stable = logits - max_logits
    
    # Compute exponentials
    exp_logits = torch.exp(logits_stable)
    
    # Compute sum of exponentials
    sum_exp_logits = torch.sum(exp_logits, dim=1, keepdim=True)
    
    # Check for zero sum of exponentials
    if (sum_exp_logits == 0).any():
        print("Warning: Zero sum of exponentials detected")
        return float('nan')
    
    # Compute log of sum of exponentials
    log_sum_exp = torch.log(sum_exp_logits) + max_logits
    
    # Compute logits of correct classes
    correct_logits = logits[torch.arange(batch_size), targets]
    
    # Compute per-sample losses
    losses = log_sum_exp.squeeze() - correct_logits
    
    # Compute mean loss
    loss = torch.mean(losses)
    return loss

@torch.no_grad()
def get_recons_loss(
    sparse_autoencoder: SparseAutoencoder,
    model: HookedViT,
    batch_tokens: torch.Tensor,
    gt_labels: torch.Tensor,
    all_labels: List[str],
    text_embeddings: torch.Tensor,
    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
):
    # Move model to device if it's not already there
    model = model.to(device)
    
    # Move all tensors to the same device
    batch_tokens = batch_tokens.to(device)
    gt_labels = gt_labels.to(device)
    text_embeddings = text_embeddings.to(device)

    # Get image embeddings
    image_embeddings, _ = model.run_with_cache(batch_tokens)

    # Calculate similarity scores
    softmax_values, top_k_indices = get_similarity(image_embeddings, text_embeddings, device=device)

    # # Print ground truth and top 5 predictions for each image in the batch
    # for idx in range(len(batch_tokens)):
    #     print(f"Ground Truth: {all_labels[gt_labels[idx]]}")
    #     print("Top 5 predictions:")
    #     for k in range(5):
    #         predicted_idx = top_k_indices[idx, k].item()
    #         predicted_label = all_labels[predicted_idx]
    #         print(f"  {k+1}. {predicted_label}")
    #     print()

    # Calculate cross-entropy loss
    loss = F.cross_entropy(softmax_values, gt_labels)
    # Safely extract the loss value
    loss_value = loss.item() if torch.isfinite(loss).all() else float('nan')


    head_index = sparse_autoencoder.cfg.hook_point_head_index
    hook_point = sparse_autoencoder.cfg.hook_point

    def standard_replacement_hook(activations: torch.Tensor, hook: Any):
        activations = sparse_autoencoder.forward(activations)[0].to(activations.dtype)
        return activations

    def head_replacement_hook(activations: torch.Tensor, hook: Any):
        new_activations = sparse_autoencoder.forward(activations[:, :, head_index])[0].to(activations.dtype)
        activations[:, :, head_index] = new_activations
        return activations

    replacement_hook = standard_replacement_hook if head_index is None else head_replacement_hook

    recons_image_embeddings = model.run_with_hooks(
        batch_tokens,
        fwd_hooks=[(hook_point, partial(replacement_hook))],
    )
    recons_softmax_values, _ = get_similarity(recons_image_embeddings, text_embeddings, device=device)
    recons_loss = F.cross_entropy(recons_softmax_values, gt_labels)

    zero_abl_image_embeddings = model.run_with_hooks(
        batch_tokens, fwd_hooks=[(hook_point, zero_ablate_hook)]
    )
    zero_abl_softmax_values, _ = get_similarity(zero_abl_image_embeddings, text_embeddings, device=device)
    zero_abl_loss = F.cross_entropy(zero_abl_softmax_values, gt_labels)

    score = (zero_abl_loss - recons_loss) / (zero_abl_loss - loss)

    return score, loss, recons_loss, zero_abl_loss


all_labels = get_text_labels('imagenet') # wordbank or imagenet

In [None]:
text_embeddings = get_text_embeddings(model_name, all_labels)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.cuda.empty_cache()

with torch.no_grad():
    for batch_tokens, gt_labels, indices in val_dataloader:
        score, loss, recons_loss, zero_abl_loss = get_recons_loss(
            sparse_autoencoder, 
            model, 
            batch_tokens, 
            gt_labels, 
            all_labels, 
            text_embeddings,
            device
        )
        # print all values
        print(f"Score: {score:.4f}")
        print(f"Loss: {loss:.4f}")
        print(f"Reconstruction Loss: {recons_loss:.4f}")
        print(f"Zero Ablation Loss: {zero_abl_loss:.4f}")
        break

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Score: 1.0066
Loss: 6.9060
Reconstruction Loss: 6.9059
Zero Ablation Loss: 6.9075


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.cuda.empty_cache()

total_score = 0
total_loss = 0
total_recons_loss = 0
total_zero_abl_loss = 0
total_batches = 0

with torch.no_grad():
    for batch_tokens, gt_labels, indices in tqdm(val_dataloader):
        score, loss, recons_loss, zero_abl_loss = get_recons_loss(
            sparse_autoencoder, 
            model, 
            batch_tokens, 
            gt_labels, 
            all_labels, 
            text_embeddings,
            device
        )
        
        total_score += score
        total_loss += loss
        total_recons_loss += recons_loss
        total_zero_abl_loss += zero_abl_loss
        total_batches += 1

        # # Optional: Print progress every 10 batches
        # if total_batches % 1000 == 0:
        #     print(f"Processed {total_batches} batches...")

# Calculate averages
avg_score = total_score / total_batches
avg_loss = total_loss / total_batches
avg_recons_loss = total_recons_loss / total_batches
avg_zero_abl_loss = total_zero_abl_loss / total_batches

# Print final results
print("\nFinal Results:")
print(f"Average Score: {avg_score:.4f}")
print(f"Average Loss: {avg_loss:.4f}")
print(f"Average Reconstruction Loss: {avg_recons_loss:.4f}")
print(f"Average Zero Ablation Loss: {avg_zero_abl_loss:.4f}")

  0%|          | 0/1563 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling


Final Results:
Average Score: 1.0022
Average Loss: 6.9063
Average Reconstruction Loss: 6.9063
Average Zero Ablation Loss: 6.9074





# Maximally activating images

In [None]:

# get random features from different bins

log_freq = log_freq.to('cuda')

interesting_features_indices = []
interesting_features_values = []
interesting_features_category = []
number_features_per = 50
for condition, condition_text in zip(conditions, condition_texts):
    

    potential_indices = torch.nonzero(condition, as_tuple=True)[0]

    # Shuffle these indices and select a subset
    sampled_indices = potential_indices[torch.randperm(len(potential_indices))[:number_features_per]].to('cuda')
    

    values = log_freq[sampled_indices]

    interesting_features_indices = interesting_features_indices + sampled_indices.tolist()
    interesting_features_values = interesting_features_values + values.tolist()

    interesting_features_category = interesting_features_category + [f"{condition_text}"]*len(sampled_indices)


# for v,i, c in zip(interesting_features_indices, interesting_features_values, interesting_features_category):
#     print(c, v,i)

print(set(interesting_features_category))



{'TOTAL_logfreq_[-2,-1]', 'TOTAL_logfreq_[-5,-4]', 'TOTAL_logfreq_[-∞,-8]', 'TOTAL_logfreq_[-1,∞]', 'TOTAL_logfreq_[-4,-3]', 'TOTAL_logfreq_[-3,-2]'}


In [None]:
from typing import List, Dict, Tuple
import torch
import einops
from tqdm import tqdm

@torch.no_grad()
def compute_feature_activations(
    images: torch.Tensor,
    model: torch.nn.Module,
    sparse_autoencoder: torch.nn.Module,
    encoder_weights: torch.Tensor,
    encoder_biases: torch.Tensor,
    feature_ids: List[int],
    feature_categories: List[str],
    top_k: int = 10
) -> Dict[int, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Compute the highest activating tokens for given features in a batch of images.
    
    Args:
        images: Input images
        model: The main model
        sparse_autoencoder: The sparse autoencoder
        encoder_weights: Encoder weights for selected features
        encoder_biases: Encoder biases for selected features
        feature_ids: List of feature IDs to analyze
        feature_categories: Categories of the features
        top_k: Number of top activations to return per feature

    Returns:
        Dictionary mapping feature IDs to tuples of (top_indices, top_values)
    """
    _, cache = model.run_with_cache(images, names_filter=[sparse_autoencoder.cfg.hook_point])
    
    layer_activations = cache[sparse_autoencoder.cfg.hook_point]
    batch_size, seq_len, _ = layer_activations.shape
    flattened_activations = einops.rearrange(layer_activations, "batch seq d_mlp -> (batch seq) d_mlp")
    
    sae_input = flattened_activations - sparse_autoencoder.b_dec
    feature_activations = einops.einsum(sae_input, encoder_weights, "... d_in, d_in n -> ... n") + encoder_biases
    feature_activations = torch.nn.functional.relu(feature_activations)
    
    reshaped_activations = einops.rearrange(feature_activations, "(batch seq) d_in -> batch seq d_in", batch=batch_size, seq=seq_len)
    cls_token_activations = reshaped_activations[:, 0, :]
    mean_image_activations = reshaped_activations.mean(1)

    top_activations = {}
    for i, (feature_id, feature_category) in enumerate(zip(feature_ids, feature_categories)):
        if "CLS_" in feature_category:
            top_values, top_indices = cls_token_activations[:, i].topk(top_k)
        else:
            top_values, top_indices = mean_image_activations[:, i].topk(top_k)
        top_activations[feature_id] = (top_indices, top_values)
    
    return top_activations

def find_top_activations(
    val_dataloader: torch.utils.data.DataLoader,
    model: torch.nn.Module,
    sparse_autoencoder: torch.nn.Module,
    cfg: object,
    interesting_features_indices: List[int],
    interesting_features_category: List[str],
    top_k: int = 16,
) -> Dict[int, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Find the top activations for interesting features across the validation dataset.

    Args:
        val_dataloader: Validation data loader
        model: The main model
        sparse_autoencoder: The sparse autoencoder
        cfg: Configuration object
        interesting_features_indices: Indices of interesting features
        interesting_features_category: Categories of interesting features

    Returns:
        Dictionary mapping feature IDs to tuples of (top_values, top_indices)
    """
    max_samples = cfg.eval_max

    top_activations = {i: (None, None) for i in interesting_features_indices}
    encoder_biases = sparse_autoencoder.b_enc[interesting_features_indices]
    encoder_weights = sparse_autoencoder.W_enc[:, interesting_features_indices]

    processed_samples = 0
    for batch_images, _, batch_indices in tqdm(val_dataloader, total=max_samples // cfg.batch_size):
        batch_images = batch_images.to(cfg.device)
        batch_indices = batch_indices.to(cfg.device)
        batch_size = batch_images.shape[0]

        batch_activations = compute_feature_activations(
            batch_images, model, sparse_autoencoder, encoder_weights, encoder_biases,
            interesting_features_indices, interesting_features_category, top_k
        )

        for feature_id in interesting_features_indices:
            new_indices, new_values = batch_activations[feature_id]
            new_indices = batch_indices[new_indices]
            
            if top_activations[feature_id][0] is None:
                top_activations[feature_id] = (new_values, new_indices)
            else:
                combined_values = torch.cat((top_activations[feature_id][0], new_values))
                combined_indices = torch.cat((top_activations[feature_id][1], new_indices))
                _, top_k_indices = torch.topk(combined_values, top_k)
                top_activations[feature_id] = (combined_values[top_k_indices], combined_indices[top_k_indices])

        processed_samples += batch_size
        if processed_samples >= max_samples:
            break

    return {i: (values.detach().cpu(), indices.detach().cpu()) 
            for i, (values, indices) in top_activations.items()}

# Usage
top_activations_per_feature = find_top_activations(
    val_dataloader, model, sparse_autoencoder, cfg,
    interesting_features_indices, interesting_features_category
)

  0%|          | 0/1562 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling

In [None]:
from vit_prisma.dataloaders.imagenet_dataset import get_imagenet_index_to_name
ind_to_name = get_imagenet_index_to_name()


In [None]:

torch.no_grad()
def get_heatmap(
          image,
          model,
          sparse_autoencoder,
          feature_id,
): 
    image = image.to(cfg.device)
    _, cache = model.run_with_cache(image.unsqueeze(0))

    post_reshaped = einops.rearrange(cache[sparse_autoencoder.cfg.hook_point], "batch seq d_mlp -> (batch seq) d_mlp")
    # Compute activations (not from a fwd pass, but explicitly, by taking only the feature we want)
    # This code is copied from the first part of the 'forward' method of the AutoEncoder class
    sae_in =  post_reshaped - sparse_autoencoder.b_dec # Remove decoder bias as per Anthropic
    acts = einops.einsum(
            sae_in,
            sparse_autoencoder.W_enc[:, feature_id],
            "x d_in, d_in -> x",
        )
    return acts 
     
def image_patch_heatmap(activation_values,image_size=224, pixel_num=14):
    activation_values = activation_values.detach().cpu().numpy()
    activation_values = activation_values[1:]
    activation_values = activation_values.reshape(pixel_num, pixel_num)

    # Create a heatmap overlay
    heatmap = np.zeros((image_size, image_size))
    patch_size = image_size // pixel_num

    for i in range(pixel_num):
        for j in range(pixel_num):
            heatmap[i*patch_size:(i+1)*patch_size, j*patch_size:(j+1)*patch_size] = activation_values[i, j]

    return heatmap

    # Removing axes


for feature_ids, cat, logfreq in tqdm(zip(top_activations_per_feature.keys(), interesting_features_category, interesting_features_values), total=len(interesting_features_category)):
  #  print(f"looking at {feature_ids}, {cat}")
    max_vals, max_inds = top_activations_per_feature[feature_ids]
    images = []
    model_images = []
    gt_labels = []
    for bid, v in zip(max_inds, max_vals):

        image, label, image_ind = val_data_visualize[bid]

        assert image_ind.item() == bid
        images.append(image)

        # model_img, _, _ = imagenet_data[bid]
        model_image, _, _ = val_data[bid]
        model_images.append(model_image)
        gt_labels.append(ind_to_name[str(label)][1])
    
    grid_size = int(np.ceil(np.sqrt(len(images))))
    fig, axs = plt.subplots(int(np.ceil(len(images)/grid_size)), grid_size, figsize=(15, 15))
    name=  f"Category: {cat},  Feature: {feature_ids}"
    fig.suptitle(name)#, y=0.95)
    for ax in axs.flatten():
        ax.axis('off')
    complete_bid = []

    for i, (image_tensor, label, val, bid,model_img) in enumerate(zip(images, gt_labels, max_vals,max_inds,model_images )):
        if bid in complete_bid:
            continue 
        complete_bid.append(bid)


        row = i // grid_size
        col = i % grid_size
        heatmap = get_heatmap(model_img,model,sparse_autoencoder, feature_ids )
        heatmap = image_patch_heatmap(heatmap, pixel_num=224//cfg.patch_size)

        display = image_tensor.numpy().transpose(1, 2, 0)

        has_zero = False
        

        axs[row, col].imshow(display)
        axs[row, col].imshow(heatmap, cmap='viridis', alpha=0.3)  # Overlaying the heatmap
        axs[row, col].set_title(f"{label} {val.item():0.03f} {'class token!' if has_zero else ''}")  
        axs[row, col].axis('off')  

    plt.tight_layout()

    folder = os.path.join(cfg.max_image_output_folder, f"{cat}")
    os.makedirs(folder, exist_ok=True)
    plt.savefig(os.path.join(folder, f"neglogfreq_{-logfreq}feauture_id:{feature_ids}.png"))
    plt.close()


NameError: name 'top_per_feature' is not defined

In [None]:
cfg.max_image_output_folder

'/network/scratch/s/sonia.joseph/sae_checkpoints/max_images/1f89d99e-wkcn-TinyCLIP-ViT-40M-32-Text-19M-LAION400M-expansion-16'