# Exploring activations

In [1]:
import os
import pickle
import re
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm


## Preparing "training" data

Load the activations for the chosen benchmark and model:

In [2]:
# Where the pre-computed activations live, and which benchmark/model pair to analyse.
activation_paths = '../activations/'
benchmark = 'MHalu'
model = 'llava_ov_7b'

# Each run is stored in a sub-directory named "<benchmark>+<model>".
path = os.path.join(activation_paths, f'{benchmark}+{model}')

In [3]:
# Load the training split: per-sample activations (indexable by sample index)
# and a mapping from each class label to the indices of its samples.
# NOTE: pickle.load can execute arbitrary code - only load files you trust.
with open(path + '/train_activations.pkl', 'rb') as f:
    activations = pickle.load(f)

with open(path + '/train_classes.pkl', 'rb') as f:
    labels_to_indices = pickle.load(f)

# Reverse dictionary: sample index -> class label.
indices_to_labels = {
    value: key
    for key, values in labels_to_indices.items()
    for value in values
}

Compute each class centroid (the mean training activation of that class):


In [4]:
# Class centroid = mean of the training activations of every sample in the class.
# Summing lazily (0 + a0 + a1 + ...) adds the samples in the same order as an
# explicit zero-initialised accumulator would. Unlike that approach, it does not
# need a sample with key 0 to exist just to obtain a template tensor.
centroids = dict()
for label, sample_indices in labels_to_indices.items():
    if len(sample_indices) == 0:
        # Averaging over zero samples would silently produce a NaN/inf centroid.
        raise ValueError(f"Class {label!r} has no training samples; cannot compute its centroid")
    centroids[label] = sum(activations[i] for i in sample_indices) / len(sample_indices)

In [5]:
# Assign each class label an integer id (in centroid insertion order), then stack
# the centroids into one tensor of shape [num_classes, num_heads, dim_heads]
# whose row i is the centroid of int_to_str[i].
int_to_str = dict(enumerate(centroids.keys()))
str_to_int = {label: class_id for class_id, label in int_to_str.items()}

class_activations = torch.zeros([len(centroids)] + list(activations[0].shape))
for class_id, label in int_to_str.items():
    class_activations[class_id] = centroids[label]
class_activations.shape

torch.Size([2, 784, 128])

## Selecting top heads
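This closely follows the base implementation: for every training sample, each head votes for the class whose centroid is most cosine-similar on that head, and we count how often each head votes for the correct class.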

In [6]:
def record_head_performance_base(class_activations, cur_activation, label, success_count):
    """Credit every head whose nearest class centroid matches the true label.

    For each head ``h``, the predicted class is the one whose centroid
    ``class_activations[:, h, :]`` is most cosine-similar to ``cur_activation[h, :]``
    (ties go to the lowest class id, as with ``argmax``). If that class equals
    ``label``, ``success_count[h]`` is incremented.

    Parameters
    ----------
    class_activations : torch.Tensor
        ``(num_classes, num_heads, hidden_dim)`` per-class centroids.
    cur_activation : torch.Tensor
        ``(num_heads, hidden_dim)`` activations of a single sample.
    label : int
        Integer id of the sample's true class (row index into ``class_activations``).
    success_count : list[int]
        Per-head running counts of length ``num_heads``; updated in place.

    Raises
    ------
    ValueError
        If ``cur_activation`` does not have exactly one row per head of
        ``class_activations``.
    """
    # TODO: try other similarity measures here.
    num_heads = class_activations.shape[1]
    if cur_activation.shape[0] != num_heads:
        raise ValueError(
            f"cur_activation has {cur_activation.shape[0]} heads, expected {num_heads}"
        )

    # Broadcasting (1, H, D) against (C, H, D) gives a (C, H) similarity matrix
    # in one call. The argmax over classes is then each head's vote.
    similarities = torch.nn.functional.cosine_similarity(
        class_activations, cur_activation.unsqueeze(0), dim=-1
    )
    votes = similarities.argmax(dim=0)

    for head in torch.nonzero(votes == label).flatten().tolist():
        success_count[head] += 1

Count, for each head, how many training samples it assigns to the correct class:

In [7]:
# success_count[h] = number of training samples whose nearest centroid,
# measured on head h alone, is the sample's true class.
success_count = [0] * class_activations.shape[1]

for index, activation in tqdm(activations.items()):
    true_label = indices_to_labels[index]
    record_head_performance_base(class_activations, activation, str_to_int[true_label], success_count)

arr = np.array(success_count)

100%|██████████| 40/40 [00:00<00:00, 71.24it/s]


Keep the $k$ heads with the highest success counts:

In [8]:
# Number of top-scoring heads to keep (use k = len(arr) to keep every head).
k = 15

# Indices of the k heads with the highest success counts, in ascending head order.
# Slicing from `len(arr) - k` instead of `-k` matters for k == 0: `[-0:]` would
# select every head rather than none.
topk_indices = np.sort(np.argsort(arr)[max(len(arr) - k, 0):])
topk_indices

array([  1,   4, 110, 147, 321, 322, 501, 525, 557, 563, 581, 584, 585,
       756, 776])

## Evaluate "test" set

First, some helper functions:

In [9]:
def get_top_heads(all_heads, topk_indices):
    """Select a subset of heads from an activation tensor.

    Parameters
    ----------
    all_heads : torch.Tensor
        Either ``(num_heads, head_dim)`` (one sample) or
        ``(num_samples, num_heads, head_dim)`` (a batch, e.g. class centroids).
    topk_indices : sequence of int
        Head indices to keep. The output follows this order exactly; sorting is
        not required.

    Returns
    -------
    torch.Tensor
        A new tensor with the same rank as ``all_heads`` and the head axis
        reduced to ``len(topk_indices)``. It is allocated with ``torch.zeros``,
        so it has the default dtype and device. Values are copied and cast.

    Raises
    ------
    ValueError
        If ``all_heads`` is neither 2-D nor 3-D.
    """
    if all_heads.dim() not in (2, 3):
        raise ValueError("Unrecognized shape for activations")

    # ascontiguousarray handles lists, empty inputs and reversed numpy views
    # (negative strides), which torch.from_numpy alone would reject.
    index = torch.from_numpy(np.ascontiguousarray(topk_indices, dtype=np.int64))
    index = index.to(all_heads.device)

    # `...` absorbs the optional leading sample axis, so one gather covers both ranks.
    selected = all_heads[..., index, :]
    top_heads = torch.zeros(tuple(selected.shape))
    top_heads.copy_(selected)
    return top_heads

In [10]:
# Restrict every class centroid to the selected heads: [num_classes, k, dim_heads].
top_class_activations = get_top_heads(all_heads=class_activations, topk_indices=topk_indices)
top_class_activations.shape

torch.Size([2, 15, 128])

In [11]:
def retrieve_examples(sample_activations, cur_activation):
    """Rank classes by the number of heads that vote for them.

    Each head ``h`` of ``cur_activation`` votes for the row of
    ``sample_activations[:, h, :]`` (class_activations limited to the top heads)
    with the highest cosine similarity to ``cur_activation[h, :]``.

    Returns the voted-for class indices, most votes first. Classes with equal
    votes keep the order in which they first received a vote, as
    ``Counter.most_common`` does.
    """
    cosine = torch.nn.functional.cosine_similarity
    votes = Counter(
        cosine(sample_activations[:, head, :], cur_activation[head, :], dim=-1).argmax(dim=0).item()
        for head in range(cur_activation.shape[0])
    )
    return [class_id for class_id, _count in votes.most_common()]

## Load test data

In [12]:
#Get number of test chunks
number_of_chunks = 0
for file in os.listdir(path):
    if 'test_activations' in file:
        number_of_chunks += 1

#load ground_truth once and for all
with open(path + '/test_classes.pkl', 'rb') as f:
    test_labels_to_indices = pickle.load(f)


## Load all test chunks, then run inference

In [13]:
# go through the test data by chunk :
test = {}
for i in range(number_of_chunks):
    # load chunk
    with open(path + f'/test_activations_{i}.pkl', 'rb') as f:
        test_chunk = pickle.load(f)
        test.update(test_chunk)


In [14]:
# Number of test samples, and one correctness flag per sample (all 0.0 for now).
test_size = len(test)
results = np.zeros(shape=(test_size,))

In [15]:
# Score every test sample. Only the selected top heads are kept; each head votes
# for its most cosine-similar class centroid, and the most-voted class is the
# prediction. The loop variable is deliberately not named `activations`: that
# name holds the training activations, and overwriting it would break re-running
# the earlier cells.
if set(test) != set(range(test_size)):
    raise ValueError("Test sample indices must be exactly 0..test_size-1 to index `results`")

# Set lookups avoid a linear scan of each class' index list for every sample.
test_label_sets = {label: set(indices) for label, indices in test_labels_to_indices.items()}

results = np.zeros(test_size)  # 1 if the prediction is correct for that sample, else 0
for index, test_activation in test.items():
    top_heads = get_top_heads(test_activation, topk_indices)      # only the heads we want
    preds = retrieve_examples(top_class_activations, top_heads)   # class ids ranked by votes
    if not preds:
        raise ValueError("No head votes were cast; topk_indices must be non-empty")
    pred = preds[0]
    str_class = int_to_str[pred]                                  # string prediction
    # 0 if the index isn't in the predicted class' indices
    results[index] = int(index in test_label_sets[str_class])

if benchmark == 'natural_ret':
    # Samples come in consecutive groups of 4 (2 questions x 2 images, per the
    # pairings below). The reshape requires whole groups.
    if test_size % 4 != 0:
        raise ValueError(f"natural_ret expects a multiple of 4 test samples, got {test_size}")
    group_results = results.reshape(-1, 4)
    total_groups = group_results.shape[0]

    # Raw accuracy: average correctness across all samples
    raw_acc = results.mean()
    raw_std = results.std()

    # Question accuracy:
    # For each group, the first question is correct if both sample 0 and 1 are correct,
    # and the second question is correct if both sample 2 and 3 are correct.
    q_first = group_results[:, 0] * group_results[:, 1]
    q_second = group_results[:, 2] * group_results[:, 3]
    q_correct = np.sum(q_first) + np.sum(q_second)
    q_acc = q_correct / (total_groups * 2)  # Two questions per group
    # Combined array of question outcomes (0 or 1 per question)
    q_outcomes = np.concatenate((q_first, q_second))
    q_std = q_outcomes.std()

    # Image accuracy:
    # First image is correct if samples 0 and 2 are correct;
    # Second image is correct if samples 1 and 3 are correct.
    i_first = group_results[:, 0] * group_results[:, 2]
    i_second = group_results[:, 1] * group_results[:, 3]
    i_correct = np.sum(i_first) + np.sum(i_second)
    i_acc = i_correct / (total_groups * 2)  # Two images per group
    # Combined array of image outcomes (0 or 1 per image)
    i_outcomes = np.concatenate((i_first, i_second))
    i_std = i_outcomes.std()

    # Group accuracy: a group is correct if all four samples are correct.
    g_outcomes = np.all(group_results == 1, axis=1).astype(float)
    g_acc = g_outcomes.mean()
    g_std = g_outcomes.std()

    # Print the accuracies (the *_std values above are computed but not printed).
    print(f"Raw Accuracy: {raw_acc:.4f}")
    print(f"Question Accuracy: {q_acc:.4f}")
    print(f"Image Accuracy: {i_acc:.4f}")
    print(f"Group Accuracy: {g_acc:.4f}")

else:
    accuracy = results.mean()
    std = results.std()
    print(accuracy)
    print(std)

0.8123280132085856
0.390450012376972
