# Distributions

We look at neuron activation distributions

## Get Activations
We first import the dependencies and run the model to collect some neuron activation distributions.

In [None]:
# Environment bootstrap: the import below only succeeds when `google.colab`
# is installed, so the pip install runs in that case and is skipped (via the
# ModuleNotFoundError handler) everywhere else.
# The `!` line is an IPython shell escape, not plain Python.
try: # if in google colab, download necessary python files
  import google.colab 
  ! pip install -qq separability
except ModuleNotFoundError:
  pass

In [None]:
import torch
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from separability import Model
from separability.activations import get_midlayer_activations

In [None]:
# Load the model under study in fp32, and name the dataset that is later
# handed to get_midlayer_activations.
# NOTE(review): `limit=1000` presumably caps the tokens read per text -
# confirm against the separability.Model signature.
opt = Model('nickypro/tinyllama-15m', limit=1000, dtype="fp32")
dataset = 'stories'

In [None]:
# Collect mid-layer activations for both the feed-forward and the attention
# components of the model.
# NOTE(review): the positional 1e4 is presumably the number of tokens to
# sample (passed as a float) - confirm against get_midlayer_activations.
data = get_midlayer_activations( opt, dataset, 1e4, collect_ff=True, collect_attn=True )

# Move the first axis of each raw tensor to the last position.
# NOTE(review): presumably the first raw axis is the token/sample axis, so the
# result is indexed [layer][neuron][sample] - TODO confirm.
ff_activations   = data.raw["ff"].permute( (1,2,0) )
# The attention tensor has one extra axis; after the permute it is reshaped to
# (n_layers, d_model, -1) and cloned into an independent copy.
attn_activations = data.raw["attn"].permute( (1,2,3,0) ).reshape( (opt.cfg.n_layers, opt.cfg.d_model, -1) ).clone()
print( ff_activations.size() )
print( attn_activations.size() )

## Plot Distributions for Neurons
We can investigate the distribution of some random neurons in the network.

In [None]:
def plot_layer_index(activations, layer, indices, ax=None, fill=False, n_bins=100):
    """Draw log-scale histograms of selected entries of one layer.

    For every ``i`` in ``indices`` the values ``activations[layer][i]`` are
    binned with ``np.histogram`` and the counts are drawn against the bin
    midpoints using ``ax.semilogy``.

    Args:
        activations: Indexable as ``activations[layer][i]``; each entry must
            support ``.cpu().numpy()``.
        layer: Index of the layer to plot.
        indices: Iterable of entry indices within that layer.
        ax: Axes to draw on. A new figure is created when omitted.
        fill: When truthy, draw faint unlabelled curves with a light shaded
            area under each; otherwise draw solid, labelled curves.
        n_bins: Number of histogram bins.

    Returns:
        None. The plot is drawn on ``ax`` as a side effect.
    """
    if ax is None:
        _, ax = plt.subplots()

    for idx in indices:
        values = activations[layer][idx].cpu().numpy()
        counts, edges = np.histogram(values, bins=n_bins)
        centres = (edges[:-1] + edges[1:]) / 2

        if not fill:
            # Solid, labelled line so individual entries can be told apart.
            ax.semilogy(centres, counts, label=f"L{layer} Pos {idx}", alpha=1, linewidth=1)
            continue

        # Faint overlay mode: many curves are stacked, so keep each one light.
        ax.semilogy(centres, counts, label=None, alpha=0.2, linewidth=0.5)
        ax.fill_between(centres, counts, color='skyblue', alpha=0.02)

def plot_activation_indices(activations, indices):
    """Plot the histograms of ``indices`` for every 7th layer of ``opt``.

    One figure is produced per selected layer (0, 7, 14, ...), titled
    ``layer <j>`` and shown immediately.

    Args:
        activations: Indexable as ``activations[layer][i]``, as expected by
            ``plot_layer_index``.
        indices: Iterable of entry indices to draw in each figure.
    """
    # Layer count is read from opt.cfg, matching how the activation tensors
    # are shaped elsewhere in this notebook.
    for j in range(0, opt.cfg.n_layers, 7):
        fig, ax = plt.subplots()
        ax.set_title(f"layer {j}")
        # Pass the axes explicitly: without it plot_layer_index opens a second
        # figure, leaving this titled one empty.
        plot_layer_index(activations, j, indices, ax=ax)
        plt.show()

def plot_multiple(activations, layer, indices, labels, xlim, ylim, fill=False):
    """Show several activation sets side by side, one subplot per set.

    Args:
        activations: Sequence of activation collections; entry ``k`` is drawn
            in subplot ``k`` via ``plot_layer_index``.
        layer: Layer index forwarded to ``plot_layer_index``.
        indices: Entry indices forwarded to ``plot_layer_index``.
        labels: Sequence of names, one per subplot, used in the x-axis label.
        xlim: x-axis limits applied to every subplot.
        ylim: y-axis limits applied to every subplot; also the extent of the
            faint vertical marker drawn at x = 0.
        fill: Forwarded to ``plot_layer_index``; the legend is only drawn
            when this is falsy.
    """
    # squeeze=False always yields a 2-D grid, so the single row can be taken
    # directly whether there is one subplot or many.
    _, grid = plt.subplots(1, len(activations), figsize=(10, 4), squeeze=False)
    panels = grid[0]

    panels[0].set_ylabel("Unnormalized Probability Density")

    for k, panel in enumerate(panels):
        panel.set_xlabel(f"Neuron value in {labels[k]}")
        plot_layer_index(activations[k], layer, indices, ax=panel, fill=fill)
        # Nearly invisible dotted line marking x = 0.
        panel.semilogy([0, 0], ylim, ":k", alpha=0.01)
        panel.set_xlim(xlim)
        panel.set_ylim(ylim)
        if fill:
            continue
        panel.legend()

    plt.show()
        
    
    


In [None]:
# Rebuild the flattened attention activations with the same expression used
# in the activation-collection cell: permute the first raw axis to the end,
# reshape to (n_layers, d_model, -1) and clone.
# NOTE(review): presumably repeated so this section can be re-run on its own.
attn_activations = data.raw["attn"].permute( (1,2,3,0) ).reshape( (opt.cfg.n_layers, opt.cfg.d_model, -1) ).clone()

In [None]:
print(attn_activations.shape)

# A value is counted as "zero" when its magnitude, scaled by the inverse of
# the layer's mean per-neuron std, is below zero_threshold. A neuron is
# flagged as "zeroed" when zeros outnumber non-zeros by more than zero_ratio.
zero_threshold = 1.0
zero_ratio     = 1.0

# Std of every neuron over the last axis, then the mean of those per layer.
all_stds  = attn_activations.std(dim=-1)
mean_stds = all_stds.mean(dim=-1)
__norm    = 1 / mean_stds[..., None, None]

# Scale once and compare twice; the temporary is dropped afterwards so the
# notebook namespace does not keep an extra full-size tensor alive.
scaled_magnitude = attn_activations.abs() * __norm
n_zeros     = (scaled_magnitude <  zero_threshold).sum(dim=-1)
n_non_zeros = (scaled_magnitude >= zero_threshold).sum(dim=-1)
del scaled_magnitude
print(n_zeros.shape, n_non_zeros.shape)

# Boolean mask per (layer, neuron) and its complement.
zeroness_score         = n_zeros / n_non_zeros
zeroed_activations     = zeroness_score > zero_ratio
non_zeroed_activations = ~zeroed_activations

# Keep the activations of one group and blank the other group to 0.
attn_zeroed     = attn_activations * zeroed_activations[..., None]
attn_not_zeroed = attn_activations * non_zeroed_activations[..., None]

n_layers, d_attn = n_zeros.shape

from scipy.signal import find_peaks

def find_range(xs, x_0):
    l, r = x_0, x_0
    while l > 0 and xs[l] > 0:
        l -= 1
    while r < len(xs) and xs[r] > 0:
        r += 1
    return l, r

# A neuron's dominant peak counts as "at zero" when it lies within this many
# layer-stds of 0. Loop-invariant, so it is set once.
peak_threshold = 1.0

for l in range(n_layers):
    # Mean per-neuron std of this layer, used as the layer's overall scale.
    std   = attn_activations[l].std(dim=-1).mean(dim=-1).item()
    print(f"Layer {l+1} (std {std}):")
    # plot_multiple shows its own figure, so no extra plt.show() follows.
    plot_multiple([attn_zeroed, attn_not_zeroed],
                l, range(d_attn), ["zeroed", "not_zeroed"], [-0.5, 0.5], [0.9, 3e3], True)

    # Diagnostic figure: only neurons with several peaks or an off-zero peak.
    fig, ax = plt.subplots()
    all_peaks = []
    for n in range(d_attn):
        acts = attn_activations[l,n]
        counts, bins = np.histogram(acts.cpu().numpy(), bins=50)
        mids = ( bins[:-1] + bins[1:] )/2
        peaks, _ = find_peaks(counts, height=200, distance=40)
        if len(peaks) > 0:
            # Of the detected peaks, take the one closest to zero.
            peak_idx = peaks[np.argmin(np.abs(mids[peaks]))]
        else:
            # find_peaks can return nothing (no bin reaches `height`, or the
            # maximum sits in an edge bin, which is never reported as a
            # peak). Fall back to the tallest bin instead of raising on an
            # empty argmin.
            peak_idx = int(np.argmax(counts))

        # Bins holding more than half of the peak's count; if that plateau is
        # wide, use its centre rather than the raw peak bin.
        close_acts = (counts*2 > counts[peak_idx])
        min_idx, max_idx = find_range(close_acts, peak_idx)
        if (max_idx - min_idx) > 20:
            peak_idx = (min_idx + max_idx) // 2

        peak_pos = mids[peak_idx]

        # Use the magnitude so peaks far below zero are shown too, matching
        # the abs() in the zero-peak criterion further down.
        if len(peaks) > 1 or abs(peak_pos)/std > peak_threshold:
            plot_layer_index([[acts]], 0, [0], ax=ax)
            ax.plot(mids, close_acts, ":k")
            ax.plot(mids[peaks], counts[peaks], "o", label=f"layer {l} peak {n}")
        all_peaks.append(peak_pos)
    plt.show()

    # Plot peaks based graphs
    # Convert once to a tensor on the same device as the activations, and
    # reuse it below rather than wrapping it in torch.tensor again.
    all_peaks = torch.tensor(np.array(all_peaks)).to(attn_activations.device)
    zero_peak_criteria  = all_peaks.abs().unsqueeze(dim=-1) / std < peak_threshold
    attn_zero_peaks     = attn_activations[l] * zero_peak_criteria
    attn_non_zero_peaks = attn_activations[l] * torch.logical_not(zero_peak_criteria)
    # Shift each neuron by its peak (or its mean) and rescale from its own std
    # to the layer's std.
    attn_centered       = std * (attn_activations[l] - all_peaks.unsqueeze(dim=-1)) / (all_stds[l].unsqueeze(dim=-1))
    attn_mean_centered  = std * (attn_activations[l] - attn_activations[l].mean(dim=-1).unsqueeze(dim=-1)) / (all_stds[l].unsqueeze(dim=-1))
    plot_multiple([[attn_zero_peaks], [attn_non_zero_peaks], [attn_centered], [attn_mean_centered]],
        0, range(d_attn), ["zeroed", "not_zeroed", "peak centered", "mean centered"], [-0.5, 0.5], [0.9e1, 3e3], True)
    