# Distributions

We look at the distributions of neuron activations in the model.

## Get Activations
We first import the dependencies, then run the model to collect neuron activations whose distributions we can inspect.

In [None]:
# Detect Google Colab by probing for its module; only when it is present do we
# pip-install the `seperability` package (that spelling is the real package
# name, matching the `from seperability import ...` lines below).
try: # in Google Colab: quietly install `seperability` via pip
  import google.colab 
  # `!` is IPython shell-escape syntax, so this cell only runs in a notebook kernel.
  ! pip install -qq seperability
except ModuleNotFoundError:
  pass  # not in Colab: assume the dependencies are already installed locally

In [None]:
import torch
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from seperability import Model
from seperability.activations import get_midlayer_activations

In [None]:
# Name of the dataset handed to `get_midlayer_activations` below.
dataset = "pile"
# Load the "125m" model; `limit=1000` is forwarded to `Model` as-is
# (presumably a token/context limit — confirm against the seperability docs).
opt = Model("125m", limit=1000)

In [None]:
# Run the model over `dataset` and keep both the feed-forward ("ff") and the
# attention ("attn") mid-layer activations.
data = get_midlayer_activations(
    opt,
    dataset,
    1e4,
    collect_ff=True,
    collect_attn=True,
)

# Move the leading axis of each raw tensor to the end. The attention tensor is
# then flattened to (opt.n_layers, opt.d_model, -1). The meaning of the raw axes
# is not visible here — presumably (samples, layers, ...); confirm upstream.
ff_activations = data["raw"]["ff"].permute(1, 2, 0)
attn_activations = (
    data["raw"]["attn"]
    .permute(1, 2, 3, 0)
    .reshape(opt.n_layers, opt.d_model, -1)
)

# Sanity-check the resulting shapes.
print(ff_activations.size())
print(attn_activations.size())

## Plot Distributions for Neurons
We can investigate the distribution of some random neurons in the network.

In [None]:
def plot_activation_indices(activations, indices, bins=100):
    """Show, for every layer, log-scale activation histograms of chosen neurons.

    One figure is displayed per layer; each entry of ``indices`` becomes one
    labelled curve in that figure.

    Args:
        activations: Indexable as ``activations[layer][neuron]``, each entry a
            tensor of samples supporting ``.cpu().numpy()`` (assumed to be the
            layers-first layout built above — confirm for other inputs).
        indices: Neuron indices to plot in every layer.
        bins: Number of histogram bins (defaults to 100, the previous fixed value).
    """
    # Iterate over the layers actually present in `activations` instead of the
    # notebook global `opt.n_layers`, so the helper works on any such tensor.
    for layer, layer_acts in enumerate(activations):
        for neuron in indices:
            counts, edges = np.histogram(layer_acts[neuron].cpu().numpy(), bins=bins)
            mids = (edges[:-1] + edges[1:]) / 2  # bin centres for plotting
            # Counts span orders of magnitude, hence the log y-axis.
            plt.semilogy(mids, counts, label=f"neuron {neuron}")
        plt.title(f"layer {layer}")
        plt.legend()
        plt.show()

In [None]:
# Feed-forward neurons 0, 10 and 100..500 (step 100), one figure per layer.
plot_activation_indices(ff_activations, [0, 10, *range(100, 501, 100)])

In [None]:

# Same neuron selection, but for the attention activations.
plot_activation_indices(attn_activations, [0, 10, *range(100, 501, 100)])

In [None]:
# Mean activation of every attention dimension, averaged over the last axis.
attn_means = attn_activations.mean(dim=-1)
# Per layer, the dimension indices ordered by ascending mean.
attn_indices = attn_means.sort(dim=-1).indices.cpu().numpy()
# Take the two highest-mean dimensions from each layer and pool them into one
# sorted array without duplicates (np.unique flattens, dedupes and sorts).
indices = np.unique(attn_indices[:opt.n_layers, -2:])
n_items = len(indices)

In [None]:
# For each layer, overlay the activation histograms of the high-mean attention
# dimensions selected above. Each dimension gets its own hue.
for layer in range(opt.n_layers):
    for k, dim in enumerate(indices):
        # Spread hues evenly by position in `indices`. The previous code put the
        # raw dimension number into the HSV "value" slot, which exceeds 1 for
        # most dims and yields invalid (>1) RGB colours.
        rgb = mpl.colors.hsv_to_rgb((k / n_items, 1, 1))
        # `indices` were chosen from attention means, so plot attention
        # activations (the original indexed ff_activations by mistake).
        counts, bins = np.histogram(attn_activations[layer][dim].cpu().numpy(), bins=50)
        mids = (bins[:-1] + bins[1:]) / 2  # bin centres for plotting
        plt.semilogy(mids, counts, label=f" dim {dim}", color=rgb)
    plt.title(f"layer {layer}")
    plt.legend()
    plt.show()