# Peak Comparisons
Comparing peak locations with mean locations and zero

## Setup

In [47]:
import torch
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from taker import Model
from taker.activations import get_midlayer_activations

In [48]:
def get_mean_offsets(activations):
    # Check if the activations tensor is of type torch.float16
    if activations.dtype == torch.float16:
        # Convert to torch.float32 for mode calculation
        activations_float32 = activations.float()
    else:
        # Use the original tensor if it's already in a supported data type
        activations_float32 = activations

    # Compute the mode across the last dimension for each neuron in every layer
    #mode_values_float32, _ = torch.mode(activations_float32, dim=-1)
    mode_values_float32 = torch.mean(activations_float32, dim=-1)
    
    # If the original tensor was torch.float16, convert the result back to torch.float16
    if activations.dtype == torch.float16:
        mode_values = mode_values_float32.half()
    else:
        mode_values = mode_values_float32
    
    # The mode_values tensor will have shape [layers, neurons], which is already 2D
    # and matches the requirement of returning a 2D tensor of mode values.
    
    return mode_values

def get_bucket_peaks(activations):
    # Check if the activations tensor is of type torch.float16
    if activations.dtype == torch.float16:
        # Convert to torch.float32 for histogram calculation
        activations_float32 = activations.float()
    else:
        # Use the original tensor if it's already in a supported data type
        activations_float32 = activations

    # Prepare for histogram computation
    min_val = activations_float32.min()
    max_val = activations_float32.max()
    bins = 100

    # Initialize an empty tensor to hold the peak values
    peak_values_float32 = torch.empty(activations_float32.size()[:-1], device=activations_float32.device, dtype=torch.float32)
    
    # Compute the histogram and find the peak for each neuron in every layer
    for i in range(activations_float32.size()[0]):  # Assuming the first dimension is layers
        for j in range(activations_float32.size()[1]):  # Assuming the second dimension is neurons
            #min_val = activations_float32[i, j].min()
            #max_val = activations_float32[i, j].max()

            hist = torch.histc(activations_float32[i, j], bins=bins, min=min_val, max=max_val)
            peak_bin = hist.argmax()
            # Compute the center value of the peak bin
            bin_width = (max_val - min_val) / bins
            peak_value = min_val + bin_width * (peak_bin.float() + 0.5)
            peak_values_float32[i, j] = peak_value

    # If the original tensor was torch.float16, convert the result back to torch.float16
    if activations.dtype == torch.float16:
        peak_values = peak_values_float32.half()
    else:
        peak_values = peak_values_float32

    return peak_values

def get_kde_peaks(activations, bandwidth=0.1):
    layers, neurons, _ = activations.shape  # Assuming activations is a 3D tensor of shape [layers, neurons, activations]

    # Initialize an empty tensor for main peak values with shape [layers, neurons]
    main_peak_values = torch.empty((layers, neurons), dtype=torch.float32)

    # Ensure activations are in float32 for KDE
    if activations.dtype == torch.float16:
        activations_float32 = activations.float()
    else:
        activations_float32 = activations

    # Iterate over each layer and neuron to compute the main peak value
    for layer in range(layers):
        print(f"Calculating KDE for layer {layer+1} of {layers}")
        for neuron in range(neurons):
            # Convert activations to numpy for KDE computation
            activations_np = activations_float32[layer, neuron].cpu().numpy().flatten()

            # Perform Kernel Density Estimation
            kde = stats.gaussian_kde(activations_np, bw_method=bandwidth)
            
            # Evaluate the KDE on a fine grid to find the peak
            grid = np.linspace(activations_np.min(), activations_np.max(), 1000)
            kde_values = kde.evaluate(grid)
            
            # Identify the main peak as the grid value with the highest KDE estimate
            main_peak_value = grid[np.argmax(kde_values)]

            # Store the main peak value
            main_peak_values[layer, neuron] = main_peak_value

    # No need to adjust activations here; just return the 2D tensor of main peak values
    return main_peak_values

def get_mode_offsets(activations):
    # Check if the activations tensor is of type torch.float16
    if activations.dtype == torch.float16:
        # Convert to torch.float32 for mode calculation
        activations_float32 = activations.float()
    else:
        # Use the original tensor if it's already in a supported data type
        activations_float32 = activations

    # Compute the mode across the last dimension for each neuron in every layer
    mode_values_float32, _ = torch.mode(activations_float32, dim=-1)
    #mode_values_float32 = torch.mean(activations_float32, dim=-1)
    
    # If the original tensor was torch.float16, convert the result back to torch.float16
    if activations.dtype == torch.float16:
        mode_values = mode_values_float32.half()
    else:
        mode_values = mode_values_float32
    
    # The mode_values tensor will have shape [layers, neurons], which is already 2D
    # and matches the requirement of returning a 2D tensor of mode values.
    
    #flip the sign on everything since we are doing addition in the offset mask
    #return mode_values * -1
    return mode_values

## Get Data

In [49]:
#opt = Model('facebook/opt-125m', limit=1000)
#dataset = 'pile'

opt = Model('google/vit-base-patch16-224', limit=1000, dtype='fp32')
dataset = 'imagenet-1k-birdless'

- Loaded facebook/opt-125m
 - Registered 12 Attention Layers


In [50]:

data = get_midlayer_activations( opt, dataset, 1e4, collect_ff=False, collect_attn=True )
# [token, layer, neuron] -> [layer, neuron, token]
#ff_activations   = data.raw["ff"].permute( (1,2,0) )
# [token, layer, attention head, attention neuron] -> [layer, attention head, attention neuron, token]
attn_activations = data.raw["attn"].permute( (1,2,3,0) ).reshape( (opt.cfg.n_layers, opt.cfg.d_model, -1) )


Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

10311it [00:14, 690.10it/s]                            


In [51]:

means = get_mean_offsets(attn_activations)
peaks = get_bucket_peaks(attn_activations)

#get the distance between the mean and the peak
mean_peak_diff = torch.abs(means - peaks)

#distance of means from 0
peak_0_diff = torch.abs(means)

#distance of peaks from 0
mean_0_diff = torch.abs(peaks)

In [52]:
#print sums of all of the differences
print(f"mean_peak_diff: {mean_peak_diff.sum()}")
print(f"peak_0_diff: {peak_0_diff.sum()}")
print(f"mean_0_diff: {mean_0_diff.sum()}")

mean_peak_diff: 98.4375
peak_0_diff: 112.0625
mean_0_diff: 45.75


In [53]:
#TODO: how often is the peak at 0 (or really close)
def num_peaks_at_zero(peaks, tolerance=0.001):
    # Count how many numbers are within the tolerance of the target value
    close_to_target = torch.abs(peaks) <= tolerance
    return close_to_target.sum()

print(f"peaks at 0: {num_peaks_at_zero(peaks)}")
print(f"means at 0: {num_peaks_at_zero(means)}")

#print size of peaks
print(f"peaks size: {peaks.size()}")

# print 100 peaks
print(peaks[3, :100])

#FIXME: are all of the peaks are basically the same because of how we are doing the bucketing?


peaks at 0: 0
means at 0: 1037
peaks size: torch.Size([12, 768])
tensor([0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021,
        0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021,
        0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021,
        0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0427, 0.0021, 0.0021,
        0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021,
        0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021,
        0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021,
        0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0834,
        0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021,
        0.0021, 0.0021, 0.0427, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021,
        0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021, 0.0021,
        0.0021], dtype=torch.float16)
