In [1]:
import sys
import os
from shared.utils import load_images, plot_images_with_max_per_row, plot_attn_maps, filter_highest_layer
from shared.algorithms import find_register_neurons
import numpy as np
import matplotlib.pyplot as plt
import torch
import yaml
from tqdm import tqdm

# Make the model-specific state loaders importable (the model-loading cell
# imports `clip_state` / `dinov2_state`; presumably they live in these dirs).
sys.path.extend(["clip/", "dinov2/"])


In [2]:
# Restrict the process to one GPU. This only takes effect if CUDA has not been
# initialised yet (PyTorch initialises CUDA lazily), so run it before the model loads.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})  # Change "0" to the GPU you want to use

### Load model

In [3]:
MODEL = "dinov2"  # backbone to analyse: "clip" or "dinov2"

In [None]:
# Pick the backbone-specific state loader and its YAML config, then build the
# model state dict consumed by the rest of the notebook.
if MODEL == "dinov2":
  from dinov2_state import load_dinov2_state as load_state
  config_path = "configs/dinov2_large.yaml"
elif MODEL == "clip":
  from clip_state import load_clip_state as load_state
  config_path = "configs/openclip_base.yaml"
else:
  raise ValueError(f"Model {MODEL} not supported")

with open(config_path, "r") as f:
  config = yaml.safe_load(f)
state = load_state(config)


In [5]:
# Path to the ImageNet validation images. Override with the IMAGENET_PATH
# environment variable instead of editing this absolute path.
IMAGENET_PATH = os.environ.get("IMAGENET_PATH", "/datasets/ilsvrc/current/val")
IMAGE_SIZE = 224  # Preprocessed image size

# Handles produced by the model-loading cell.
run_model = state["run_model"]
model = state["model"]
preprocess = state["preprocess"]  # Preprocess function for input images
hook_manager = state["hook_manager"]
num_layers = state["num_layers"]
num_heads = state["num_heads"]
patch_size = state["patch_size"]
config = state["config"]

# The input is square, so the patch grid is (IMAGE_SIZE // patch_size) on each side.
patch_height = IMAGE_SIZE // patch_size
patch_width = IMAGE_SIZE // patch_size
device = "cuda:0"

### Run on one image

In [None]:
# One ImageNet validation image for the single-image walkthrough.
image = load_images(IMAGENET_PATH, count=1)[0]

In [7]:
# Baseline forward pass on one image, with hooks recording attention and layer outputs.
num_patches = patch_height * patch_width
processed_image = preprocess(image).unsqueeze(0).to(device)
hook_manager.set_debug(True)
hook_manager.reinit()
hook_manager.finalize()
representation = run_model(model, processed_image)

attention_maps = hook_manager.get_attention_maps()  # shape (L, H, N, N)
layer_outputs = hook_manager.get_layer_outputs()  # shape (L, N, D)

# Token 0 is CLS. Slice exactly the patch-grid tokens (as the later cells do) so the
# reshape stays valid even if extra tokens, e.g. registers, follow the patches.
patch_norms = np.linalg.norm(layer_outputs[:, 1:num_patches + 1], axis=2).reshape(
    num_layers, patch_height, patch_width
)

In [None]:
# Show the raw input image.
input_fig, input_ax = plt.subplots(figsize=(10, 10))
input_ax.imshow(image)
input_ax.axis('off')
input_ax.set_title("Input")
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Plot patch norms across all layers. Notice that outliers appear in the later layers.
# Keep the helper's return value under its own name: assigning it to `plt` would
# shadow the pyplot module for every later cell.
norm_grid_plot = plot_images_with_max_per_row(patch_norms, max_per_row=4, image_title="Layer")
norm_grid_plot.tight_layout()
norm_grid_plot.show()

In [None]:
# Plot the CLS token's attention to every image patch, per layer and head.
# Slice exactly the patch tokens so extra trailing tokens cannot break the reshape.
num_patches = patch_height * patch_width
cls_attn_maps = attention_maps[:, :, 0, 1:num_patches + 1].reshape(
    num_layers, num_heads, patch_height, patch_width
)
# Don't rebind `plt`: later cells still need the pyplot module.
cls_attn_plot = plot_attn_maps(cls_attn_maps)
cls_attn_plot.show()


### Identify the layers where outliers and attention sinks appear

In [None]:
# Per-layer averages, over a small image sample, of (a) the largest patch-token norm
# and (b) the largest head-averaged CLS->patch attention.
rand_images = load_images(IMAGENET_PATH, count=10)
num_patches = patch_height * patch_width
patch_norm_sum = np.zeros(num_layers)
attn_sum = np.zeros(num_layers)

hook_manager.set_debug(True)

# Dedicated loop variable so the walkthrough `image` chosen earlier is not overwritten.
for rand_image in tqdm(rand_images, desc="Processing images"):
  processed_image = preprocess(rand_image).unsqueeze(0).to(device)
  hook_manager.reinit()
  hook_manager.finalize()
  representation = run_model(model, processed_image)
  attention_maps = hook_manager.get_attention_maps()  # shape (L, H, N, N)
  layer_outputs = hook_manager.get_layer_outputs()  # shape (L, N, D)

  # Max over patches of the token norm, per layer.
  patch_norm_sum += np.max(np.linalg.norm(layer_outputs[:, 1:num_patches + 1], axis=2), axis=1)
  # Mean over heads, then max over patches, of CLS-row attention, per layer.
  attn_sum += np.max(np.mean(attention_maps[:, :, 0, 1:num_patches + 1], axis=1), axis=1)

# Lists of per-layer averages, consumed by the plotting cell below.
max_patch_norms = (patch_norm_sum / len(rand_images)).tolist()
max_attn_norms = (attn_sum / len(rand_images)).tolist()


In [None]:
import matplotlib.pyplot as plt

# Two side-by-side panels of per-layer statistics from the cell above.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 5), dpi=100)

# === Subplot 1: Patch Norms ===
ax1.plot(range(num_layers), max_patch_norms, marker='o', markersize=8, color='steelblue', linestyle='-', linewidth=2)
ax1.set_title('Average Max Patch Norms (Layer output)')
ax1.set_ylabel('Norm')

# === Subplot 2: Attention Norms ===
ax2.plot(range(num_layers), max_attn_norms, marker='^', markersize=8, color='orange', linestyle='-', linewidth=2)
ax2.set_title('Average Max Attention (CLS)')
ax2.set_ylabel('Attention Value')

# Shared axis styling for both panels.
for ax in (ax1, ax2):
    ax.set_xlabel('Layer')
    ax.set_xticks(range(0, num_layers, 2))  # Show every second tick
    ax.tick_params(axis='both')
    ax.tick_params(axis='x', which='major', pad=15)
    ax.margins(x=0.1)
    ax.grid(True, linestyle='-')

fig.tight_layout()
plt.show()


### Identify register neurons

In [14]:
#########################################
#              PARAMETERS               #
#########################################

# Both values are read from the model config.
#   detect_outliers_layer: layer whose patch norms are used to detect outliers.
#     Recommended: -1 (last layer) for CLIP, -2 (second-to-last layer) for DINOv2 large.
#   register_norm_threshold: threshold for detecting register neurons.
detect_outliers_layer, register_norm_threshold = (
    config["detect_outliers_layer"],
    config["register_norm_threshold"],
)

In [None]:
# Run the register-neuron search (shared.algorithms.find_register_neurons) over a
# sample of images, using the outlier layer and threshold chosen above.
search_kwargs = dict(
    model_state=state,
    image_path=IMAGENET_PATH,
    detect_outliers_layer=detect_outliers_layer,
    processed_image_cnt=100,
    register_norm_threshold=register_norm_threshold,
    apply_sparsity_filter=True,
)
register_neurons = find_register_neurons(**search_kwargs)

Optionally, save the register neurons to disk so they can be reloaded later without rerunning the search above.

In [7]:
# Cache the search result next to the notebook.
torch.save(obj=register_neurons, f="register_neurons.pt")

In [None]:
# Reload the cached search result. torch.load unpickles, so only load files this
# notebook produced. NOTE(review): on PyTorch >= 2.6 weights_only defaults to True;
# if the saved list holds non-primitive objects (e.g. NumPy scalars) loading may fail — TODO confirm.
register_neurons = torch.load(f="register_neurons.pt")

### Evaluate register neurons with test-time registers

In [17]:
#########################################
#              PARAMETERS               #
#########################################
# top_k: how many entries to take from the front of the filtered register-neuron list.
# highest_layer: passed to filter_highest_layer (presumably drops neurons above this layer — TODO confirm).
# num_registers: number of extra register tokens passed to run_model and the hook manager.
top_k, highest_layer = config["top_k"], config["highest_layer"]
num_registers = 1

In [None]:
filtered_register_neurons = filter_highest_layer(register_neurons, highest_layer)

# Group the first `top_k` (layer, neuron, score) entries by layer: {layer: [neuron, ...]}.
# The score is not needed for the intervention.
neurons_to_ablate = {}
for layer, neuron, _score in filtered_register_neurons[:top_k]:
    neurons_to_ablate.setdefault(layer, []).append(neuron)
print(neurons_to_ablate)


In [None]:
# Intervene on the selected neurons with `num_registers` extra tokens, then collect
# patch vs. register token norms and CLS attention over a sample of images.
random_images = load_images(IMAGENET_PATH, count=50)
num_patches = patch_height * patch_width
image_norms = []
register_norms = []
image_attentions = []
register_attentions = []
hook_manager.set_debug(True)
for random_image in tqdm(random_images, desc="Processing random images"):
  # Separate name for the model input: the original code rebound `image` to a tensor,
  # silently clobbering the PIL image used by the plotting cells.
  processed_random_image = preprocess(random_image).unsqueeze(0).to(device)

  hook_manager.reinit()
  hook_manager.intervene_register_neurons(neurons_to_ablate=neurons_to_ablate, num_registers=num_registers, normal_values="zero", scale=1)
  hook_manager.finalize()
  representation = run_model(model, processed_random_image, num_registers=num_registers)
  attention_maps = hook_manager.get_attention_maps()
  layer_outputs = hook_manager.get_layer_outputs()

  # Token layout used below: [CLS, patch_1..patch_P, register_1..register_R].
  layer_norms = np.linalg.norm(layer_outputs[detect_outliers_layer], axis=1)
  image_norms.extend(layer_norms[1:num_patches + 1].tolist())
  register_norms.extend(layer_norms[num_patches + 1:].tolist())

  # CLS-row attention in the last layer (independent of detect_outliers_layer),
  # maxed over heads and tokens.
  image_attentions.append(np.max(attention_maps[-1, :, 0, 1:num_patches + 1]))
  register_attentions.append(np.max(attention_maps[-1, :, 0, num_patches + 1:]))

In [None]:
# 2x2 grid of histograms: token norms (top row) and CLS attention (bottom row),
# image patches (left) vs. registers (right), each annotated with its median.
histogram_panels = [
    (image_norms, 'blue', 'Norm Value', 'Image Patch Norms'),
    (register_norms, 'green', 'Norm Value', 'Register Patch Norms'),
    (image_attentions, 'blue', 'Attention Value', 'Image Patch Attentions'),
    (register_attentions, 'green', 'Attention Value', 'Register Patch Attentions'),
]

fig, axes = plt.subplots(2, 2, figsize=(10, 12))
for ax, (values, color, xlabel, title) in zip(axes.flat, histogram_panels):
    median = np.median(values)
    ax.hist(values, bins=50, alpha=0.7, color=color)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Frequency')
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    ax.axvline(x=median, color='r', linestyle='--', label=f'Median: {median:.2f}')
    ax.legend()

fig.tight_layout()
plt.show()

# Print some statistics for comparison
for label, values in [
    ("Image norms", image_norms),
    ("Register norms", register_norms),
    ("Image attentions", image_attentions),
    ("Register attentions", register_attentions),
]:
    print(f"{label} - Min: {min(values):.2f}, Max: {max(values):.2f}, Mean: {np.mean(values):.2f}")

In [None]:
# Fresh image for the baseline-vs-intervened comparison.
image = load_images(IMAGENET_PATH, count=1)[0]

In [None]:
# Show the raw input image used for the comparison below.
comparison_fig, comparison_ax = plt.subplots(figsize=(10, 10))
comparison_ax.imshow(image)
comparison_ax.axis('off')
comparison_ax.set_title("Input")
plt.show()

In [24]:
processed_image = preprocess(image).unsqueeze(0).to(device)

# Baseline pass: no intervention, no extra register tokens.
hook_manager.reinit()
hook_manager.finalize()
representation = run_model(model, processed_image)
original_attention_maps, original_layer_outputs = (
    hook_manager.get_attention_maps(),
    hook_manager.get_layer_outputs(),
)

# Intervened pass: selected neurons handled by intervene_register_neurons (using its
# default settings here) with `num_registers` extra tokens appended.
hook_manager.reinit()
hook_manager.intervene_register_neurons(neurons_to_ablate=neurons_to_ablate, num_registers=num_registers)
hook_manager.finalize()
representation = run_model(model, processed_image, num_registers=num_registers)
ablated_attention_maps, ablated_layer_outputs, ablated_neuron_activations = (
    hook_manager.get_attention_maps(),
    hook_manager.get_layer_outputs(),
    hook_manager.get_neuron_activations(),
)

In [None]:
# Norm map of output patch embeddings - baseline and ablated comparison for all layers
from mpl_toolkits.axes_grid1 import make_axes_locatable


def reshape_with_extras(flat_array, patch_height, patch_width):
    """Split per-token values into the patch grid and any trailing extra tokens.

    Args:
        flat_array: 1-D array of per-token values with the CLS token already removed.
        patch_height, patch_width: patch-grid dimensions.

    Returns:
        (grid, extras): the first patch_height * patch_width values reshaped to
        (patch_height, patch_width), and the remaining values (e.g. register tokens),
        or None when there are none.
    """
    num_patches = patch_height * patch_width
    grid = flat_array[:num_patches].reshape((patch_height, patch_width))
    extras = flat_array[num_patches:] if len(flat_array) > num_patches else None
    return grid, extras


# One row per layer, left = original run, right = ablated run.
# squeeze=False keeps `axs` 2-D even when num_layers == 1.
fig, axs = plt.subplots(num_layers, 2, figsize=(16, 4 * num_layers), squeeze=False)
fig.subplots_adjust(hspace=0.5, wspace=0.3)

runs = [("Original", original_layer_outputs), ("Ablated", ablated_layer_outputs)]
for layer in range(num_layers):
    for col, (run_label, run_layer_outputs) in enumerate(runs):
        # Drop CLS (token 0). The ablated run also carries register tokens after the
        # patches; reshape_with_extras keeps them out of the grid.
        token_norms = np.linalg.norm(run_layer_outputs[layer, 1:], axis=1)
        norm_grid, _extras = reshape_with_extras(token_norms, patch_height, patch_width)

        ax = axs[layer, col]
        im = ax.imshow(norm_grid, cmap='viridis')
        ax.set_title(f'Layer {layer} - {run_label} (output)', fontsize=14)
        ax.set_xlabel('Patch X', fontsize=12)
        ax.set_ylabel('Patch Y', fontsize=12)

        # Colorbar in a slim axis appended to the right of this panel.
        cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.1)
        cbar = fig.colorbar(im, cax=cax)
        cbar.set_label(f'Layer {layer} Output Norm', fontsize=12)

fig.suptitle('Norm Maps of Image Patches Across All Layers', fontsize=20, y=1.00)
plt.tight_layout()
plt.show()


In [None]:
# Original attention maps: CLS-to-patch attention for every layer and head.
# The result is kept under its own name so `plt` still refers to pyplot afterwards.
num_patches = patch_height * patch_width
original_attn_plot = plot_attn_maps(
    original_attention_maps[:, :, 0, 1:num_patches + 1].reshape((num_layers, num_heads, patch_height, patch_width))
)
original_attn_plot.show()

In [None]:
# Ablated attention maps: CLS-to-patch attention (register tokens excluded) for every
# layer and head. The result is kept under its own name so `plt` is not shadowed.
num_patches = patch_height * patch_width
ablated_attn_plot = plot_attn_maps(
    ablated_attention_maps[:, :, 0, 1:num_patches + 1].reshape((num_layers, num_heads, patch_height, patch_width))
)
ablated_attn_plot.show()