|<h2>Substack post:</h2>|<h1><a href="https://mikexcohen.substack.com/p/exploring-the-reve-eeg-transformer" target="_blank">Exploring the REVE EEG model</a></h1>|
|-|:-:|
|<h2>Teacher:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the post may lead to confusion or errors.</i>

In [None]:
# reference for model: https://brain-bzh.github.io/reve/
# Hugging Face model page: https://huggingface.co/brain-bzh/reve-base

In [None]:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import gaussian_kde

from einops import rearrange

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import torch
import torch.nn.functional as F

from transformers import AutoModel
from datasets import load_dataset

In [None]:
### matplotlib adjustments

# svg plots (higher-res)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# dark theme: one background color and one foreground (text/lines) color
theme_bg = '#191919' # 191919 for substack, 282a2c otherwise
theme_fg = '#DDE2F4'

dark_theme = {}
for rc_key in ('figure.facecolor','figure.edgecolor','axes.facecolor'):
  dark_theme[rc_key] = theme_bg
for rc_key in ('axes.edgecolor','axes.labelcolor','xtick.color','ytick.color','text.color'):
  dark_theme[rc_key] = theme_fg

# cosmetics: no top/right spines, bold titles and labels, high-res saved figures
dark_theme.update({
  'axes.spines.right': False,
  'axes.spines.top':   False,
  'axes.titleweight': 'bold',
  'axes.labelweight': 'bold',
  'savefig.dpi': 300,
})

plt.rcParams.update(dark_theme)

# **Part 1: Import the model and sample dataset**

In [None]:
# log in to the Hugging Face Hub with the secret token stored in the HF_TOKEN
# environment variable (None is passed if the variable is not set)
import os
from huggingface_hub import login

login(os.getenv('HF_TOKEN'))

In [None]:
# load the electrode-position bank and the REVE encoder from the Hub
# (both repos ship custom model code, hence trust_remote_code=True)
hub_kwargs = dict(trust_remote_code=True)
pos_bank = AutoModel.from_pretrained('brain-bzh/reve-positions', **hub_kwargs)
model    = AutoModel.from_pretrained('brain-bzh/reve-base', **hub_kwargs)

# switch to evaluation mode (also displays the module tree as the cell output)
model.eval()

In [None]:
# display the model configuration (depth, embed_dim, heads, patch_size, ... are read from it below)
model.config

In [None]:
# print the source of the model's forward() to see the order of operations in a forward pass
import inspect

forward_src = inspect.getsource(type(model).forward)
print(forward_src)

In [None]:
# architecture sizes pulled from the model config
n_layers = model.config.depth      # number of transformer layers
emb_dims = model.config.embed_dim  # token embedding width
n_heads  = model.config.heads      # attention heads per layer

head_dim = emb_dims // n_heads     # dimensions per head
sqrtD    = head_dim**.5            # scale factor applied to QK^T dot products later on

print('\n'.join([
  f'There are {n_layers} layers,',
  f'embeddings dimensionality of {emb_dims}, and',
  f'{n_heads} heads, each with {head_dim} dimensions.',
]))

In [None]:
# test split of the preprocessed EEGMAT dataset, returned as torch tensors
dataset = load_dataset('brain-bzh/eegmat-prepro', split='test')
dataset.set_format('torch', columns=['data','labels'])

# channel labels (presumably in the same order as the data rows -- confirm against the dataset card)
# -> electrode coordinates from the position bank, with a leading batch dimension
channel_names = ['Fp1','Fp2','F3','F4','F7','F8','T3','T4','C3','C4',
                 'T5','T6','P3','P4','O1','O2','Fz','Cz','Pz','A2']
positions = pos_bank(channel_names).unsqueeze(0)

# time axis in seconds: 1k time points with srate=200
srate = 200
timevec = np.arange(1000) / srate

In [None]:
# display the dataset summary (features and number of rows)
dataset

In [None]:
# shape of the position tensor -- presumably (batch=1, channels, xyz); confirm from the output
positions.shape

In [None]:
# 3D scatter of the electrode coordinates, colored by their z coordinate
xyz = positions[0]   # drop the batch dimension

fig = plt.figure()
ax  = fig.add_subplot(111, projection='3d')

sc = ax.scatter(xyz[:,0], xyz[:,1], xyz[:,2],
                c=xyz[:,2], cmap='Reds', alpha=0.6, edgecolors='w',
                marker='o', s=100, linewidth=.5)

ax.view_init(elev=20, azim=45)
ax.set(xlabel='X Position', ylabel='Y Position', zlabel='Z Position', title='EEG Electrode Locations')

plt.tight_layout()
plt.show()

In [None]:
# plot one epoch: as a channels x time image (left) and as vertically offset traces (right)
epoch_idx = 0
X = dataset[epoch_idx]['data']   # (channels, time)
n_chans = X.shape[0]

fig,axs = plt.subplots(1,2,figsize=(12,4))

# origin='lower' with an extent centered on integer rows draws channel c at y=c,
# so the tick labels match the channel index and the ordering matches the right panel
# (with the default origin='upper', row 0 would be drawn at the top of a 0->N axis)
cmin,cmax = np.percentile(X,[1,99])
axs[0].imshow(X,aspect='auto',cmap='plasma',vmin=cmin,vmax=cmax,origin='lower',
              extent=[timevec[0],timevec[-1],-.5,n_chans-.5])
axs[0].set(xlabel='Time (sec.)',ylabel='Channel index',yticks=range(0,n_chans,2))

# each channel shifted up by an increasing offset (channel 0 at the bottom)
axs[1].plot(timevec,X.T + torch.linspace(0,1/X.std(),n_chans))
axs[1].set(xlim=timevec[[0,-1]],xlabel='Time (sec.)',ylabel='Channel',yticks=[])

fig.suptitle(f'Data from epoch {epoch_idx}',fontweight='bold')

plt.tight_layout()
plt.show()

# **Part 2: Embeddings**

In [None]:
# parameters used in step 1: patch length and overlap between consecutive patches
# (in samples; used as the unfold window size and to derive its step in the next cell)
model.config.patch_size, model.config.patch_overlap

In [None]:
# Step 1: EEG time series to patches
# one epoch with a leading batch dimension added -> (batch, channels, time)
X = dataset[epoch_idx]['data'].unsqueeze(0)

# slide a window along the time axis:
# (batch, channels, time) -> (batch, channels, n_patches, patch_size)
win = model.config.patch_size
hop = win - model.config.patch_overlap
patches = X.unfold(dimension=2, size=win, step=hop)

# project every patch to an embedding, then merge channels x patches into one token axis
n_chans, n_patches = patches.shape[1], patches.shape[2]
patch_emb = rearrange(model.to_patch_embedding(patches),
                      'b c h e -> b (c h) e',
                      c=n_chans, h=n_patches, e=model.config.embed_dim)

for label,tensor in [('   Data size',X), ('Patches size',patches), ('Patch-embeds',patch_emb)]:
  print(f'{label}: {list(tensor.shape)}')

In [None]:
# weights of the linear patch-embedding layer (embedding dims x time steps within a patch)
W = model.to_patch_embedding[0].weight.detach().numpy()
print(W.shape)

fig,axs = plt.subplots(1,2,figsize=(12,3.3))

# full matrix, color range clipped to the 10th-90th percentiles
lo,hi = np.percentile(W,[10,90])
axs[0].imshow(W,aspect='auto',vmin=lo,vmax=hi,cmap='plasma')
axs[0].set(xlabel='Time steps (indices)',ylabel='Embeddings dimension',title='Weights matrix')

# six example rows (every 100th embedding dim), stacked with small vertical offsets
example_rows = W[np.arange(0,511,100),:]
axs[1].plot(example_rows.T + np.linspace(0,1,6)[None,:])
axs[1].set(xlim=[0,W.shape[1]],xlabel='Time steps (indices)',yticks=[],title='A few weights')

plt.tight_layout()
plt.show()

In [None]:
# Step 2: Expand electrode positions to include a time index
# (add_time_patch appends the time-segment index; see the displayed tensor in the next cell)
n_time_patches = patches.shape[2]
pos = model.fourier4d.add_time_patch(positions, n_time_patches)

print(f'Positions size: {list(positions.shape)}')
print(f'Pos. embd size: {list(pos.shape)}')

In [None]:
# added column is time segment index
# (display the expanded positions produced by add_time_patch in step 2)
pos

In [None]:
# Step 3: Final adjusted position encoding
# two encodings of the same 4D positions are summed, then passed through model.ln
# (presumably a LayerNorm -- confirm in the model source)
pos_expand  = model.mlp4d(pos)      # mlp = multi-layer perceptron, aka feedforward network
pos_fourier = model.fourier4d(pos)  # Fourier-based encoding (per the module name)
pos_embed   = model.ln(pos_expand + pos_fourier)

for label,tensor in (('   Step 3a size',pos_expand),
                     ('   Step 3b size',pos_fourier),
                     ('Embeddings size',pos_embed)):
  print(f'{label}: {list(tensor.shape)}')

In [None]:
# heatmaps (embedding dims x tokens) of the two position encodings and their combination
panels = [(pos_expand ,'Position expansion'),
          (pos_fourier,'Position Fourier'),
          (pos_embed  ,'Position embeddings')]

fig,axs = plt.subplots(1,3,figsize=(10,4))
for ax,(encoding,title) in zip(axs,panels):
  ax.imshow(encoding[0,:,:].detach().T,aspect='auto',vmin=-.5,vmax=.5,cmap='plasma')
  ax.set(xlabel='Tokens (index)',ylabel='Embeddings dim',title=title)

plt.tight_layout()
plt.show()

In [None]:
# Step 4: Combine time series and position embeddings
# the model input is the elementwise sum of the two token embeddings
#          EEG   + channels
data = patch_emb + pos_embed

panels = [(patch_emb,'Data embeddings'),
          (pos_embed,'Electrode embeddings'),
          (data     ,'Data input to model')]

fig,axs = plt.subplots(1,3,figsize=(10,4))
for ax,(embedding,title) in zip(axs,panels):
  ax.imshow(embedding[0,:,:].detach().T,aspect='auto',vmin=-.5,vmax=.5,cmap='plasma')
  ax.set(xlabel='Tokens (index)',ylabel='Embeddings dim',title=title)

plt.tight_layout()
plt.show()

# **Part 3: Attention weights**

In [None]:
# fused Q|K|V projection weights of one transformer layer
whichlayer = 10

# to_qkv holds the three projections side by side; transpose -> (embedding dims, Q|K|V dims)
attn_block = model.transformer.layers[whichlayer][0]
wide_weights = attn_block.to_qkv.weight.detach().T

cmin,cmax = np.percentile(wide_weights,[10,90])

plt.figure(figsize=(10,3))
plt.imshow(wide_weights,vmin=cmin,vmax=cmax,cmap='plasma')
for boundary in (emb_dims, 2*emb_dims):   # Q|K and K|V boundaries
  plt.axvline(boundary,linestyle='--',color='w')
plt.colorbar(pad=.01)

plt.gca().set(xticks=[],ylabel='Embeddings dimensions',
              xlabel=' Queries dimensions         |           Keys dimensions           |           Values dimensions ',
              title=f'Attention weights from layer {whichlayer} / {n_layers}')


plt.tight_layout()
plt.show()

In [None]:
# split the fused weights into the Q, K, and V matrices (emb_dims columns each)
q,k,v = torch.split(wide_weights,emb_dims,dim=1)

# overlay the histograms of the three weight matrices
matrix_labels = ('$\\mathbf{W_Q}$','$\\mathbf{W_K}$','$\\mathbf{W_V}$')

plt.figure(figsize=(8,3))
for weights,label in zip((q,k,v),matrix_labels):
  y,x = np.histogram(weights.flatten(),bins='fd')
  plt.plot(x[:-1],y,label=label)

plt.gca().set(xlabel='Weight value',ylabel='Count',
              title=f'Distribution of QKV weights in layer {whichlayer}')
plt.legend()

plt.tight_layout()
plt.show()

In [None]:
# layer-by-layer distributions of the Q, K, and V attention weights

# common histogram boundaries
histedges = np.linspace(-.5,.5,81)

# distributions: layer x histogram bin x matrix (Q,K,V)
# distchars:     layer x statistic (mean, L1 mean = mean |w|, variance) x matrix (Q,K,V)
distributions = np.zeros((n_layers,len(histedges)-1,3))
distchars = np.zeros((n_layers,3,3))

for layeri in range(n_layers):

  # fused weights -> separate Q, K, V matrices
  wideW = model.transformer.layers[layeri][0].to_qkv.weight.detach().T
  q,k,v = torch.split(wideW,emb_dims,dim=1)

  for mi,weights in enumerate((q,k,v)):
    distributions[layeri,:,mi] = np.histogram(weights,bins=histedges,density=True)[0]
    distchars[layeri,:,mi] = np.array([weights.mean(), weights.abs().mean(), weights.var()])

# top row: laminar heatmaps (layer x weight value); bottom row: one statistic per panel
stat_names = ['Mean','L1 mean','Variance']
_,axs = plt.subplots(2,3,figsize=(12,6))
for i in range(3):
  axs[0,i].imshow(distributions[:,:,i],origin='lower',extent=[histedges[0],histedges[-1],0,n_layers],
                  aspect='auto',cmap=plt.cm.plasma,vmin=0,vmax=3.5)
  axs[0,i].set(xlabel='Weight value',ylabel='Layer',title=f"$\\mathbf{{W}}_{'QKV'[i]}$")

  axs[1,i].plot(distchars[:,i,0],'gs-',markerfacecolor=[.7,.9,.7])
  axs[1,i].plot(distchars[:,i,1],'ro-',markerfacecolor=[.9,.7,.7])
  axs[1,i].plot(distchars[:,i,2],'b^-',markerfacecolor=[.7,.7,.9])
  axs[1,i].legend(['$\\mathbf{W_Q}$','$\\mathbf{W_K}$','$\\mathbf{W_V}$'])
  axs[1,i].set(xlabel='Transformer layer',ylabel=stat_names[i])

plt.suptitle('Laminar distributions of attention weights',fontweight='bold')

plt.tight_layout()
plt.show()

# **Part 4: Feedforward weights**

In [None]:
# inspecting the architecture of the MLP block
# (layers[i][1] is the feedforward sub-block; the cells below index its .net as
#  net[1] = expansion, net[2] = GEGLU, net[3] = contraction)
model.transformer.layers[10][1].net

In [None]:
# demo of the GEGLU gated gelu operation:
# split the input in half and multiply the first half by gelu(second half)
N = 2722

x = torch.linspace(-2,2,N)        # input data
x1, x2 = x[:N//2], x[N//2:]       # value half, gate half

geglu = model.transformer.layers[10][1].net[2]
y_pt = geglu(x)                   # REVE pytorch implementation
y_mn = x1 * F.gelu(x2)            # manual implementation for comparison & clarity

# the outputs are half as long as the input, so they are plotted against x1
for xs,ys,marker in ((x,x,'s'), (x1,y_pt,'o'), (x1,y_mn,'.')):
  plt.plot(xs,ys,marker)

plt.legend([f'input (N={len(x)})',f'pytorch (N={len(y_pt)})',f'manual (N={len(y_mn)})'])
plt.gca().set(xlabel='Input',ylabel='Output')
plt.show()

In [None]:
# layer-by-layer distributions of the feedforward (MLP) expansion and contraction weights

# common histogram boundaries
histedges = np.linspace(-.5,.5,81)

# distributions: layer x histogram bin x matrix (expansion, contraction)
# distchars:     layer x statistic (mean, variance) x matrix (expansion, contraction)
distributions = np.zeros((n_layers,len(histedges)-1,2))
distchars = np.zeros((n_layers,2,2))

for layeri in range(n_layers):

  # net[1] expands into the hidden layer, net[3] contracts (projects) back
  ffnet = model.transformer.layers[layeri][1].net
  expand = ffnet[1].weight.detach()
  contract = ffnet[3].weight.detach()

  for mi,weights in enumerate((expand,contract)):
    distributions[layeri,:,mi] = np.histogram(weights,bins=histedges,density=True)[0]
    distchars[layeri,:,mi] = np.array([weights.mean(), weights.var()])   # mean and variance

# top row: laminar heatmaps (layer x weight value); bottom row: one statistic per panel
_,axs = plt.subplots(2,2,figsize=(12,6))
for i in range(2):
  axs[0,i].imshow(distributions[:,:,i],origin='lower',extent=[histedges[0],histedges[-1],0,n_layers],
                  aspect='auto',cmap=plt.cm.plasma,vmin=0,vmax=3.5)
  axs[0,i].set(xlabel='Weight value',ylabel='Layer',title=f"$\\mathbf{{W}}_{'EC'[i]}$")

  axs[1,i].plot(distchars[:,i,0],'rs-',markerfacecolor=[.9,.7,.7])
  axs[1,i].plot(distchars[:,i,1],'go-',markerfacecolor=[.7,.9,.7])
  axs[1,i].legend(['$\\mathbf{W_E}$','$\\mathbf{W_C}$'])
  axs[1,i].set(xlabel='Transformer layer',ylabel=['Mean','Variance'][i])

# these are the MLP (feedforward) weights, not the attention weights
plt.suptitle('Laminar distributions of feedforward weights',fontweight='bold')

plt.tight_layout()
plt.show()

In [None]:
# cosine similarities among the weight rows of one layer's feedforward matrices
whichlayer = 10

ffnet = model.transformer.layers[whichlayer][1].net
expand = ffnet[1].weight.detach()
contract = ffnet[3].weight.detach()

csMat_ex = cosine_similarity(expand,expand)
csMat_co = cosine_similarity(contract,contract)

fig,axs = plt.subplots(1,3,figsize=(12,4))

# subsampled similarity matrices (every 40th / 10th neuron) so the images stay readable
for ax,csMat,step,title in ((axs[0],csMat_ex,40,'Expansion layer'),
                            (axs[1],csMat_co,10,'Contraction layer')):
  h = ax.imshow(csMat[::step,::step],vmin=-.5,vmax=.5)
  fig.colorbar(h,ax=ax,pad=.01,fraction=.047)
  ax.set(xticks=[],xlabel='Neurons',yticks=[],ylabel='Neurons',title=title)

# distributions of the unique pairs (upper triangle, diagonal excluded)
simbins = np.linspace(-1,1,201)
for csMat,label in ((csMat_ex,'Expansion'),(csMat_co,'Contraction')):
  y,x = np.histogram(csMat[np.triu_indices(csMat.shape[0],1)],bins=simbins,density=True)
  axs[2].plot(x[:-1],y,linewidth=2,label=label)
axs[2].set(xlabel='Similarity values',ylabel='Density',title='Distributions',xlim=[-1,1])
axs[2].legend()

plt.tight_layout()
plt.show()

# **Part 5: Implant hooks and get activations**

In [None]:
# storage filled by the forward hooks below
# keys: 'at_<layer>_qkv' (torch tensors) and 'ff_<layer>_<0..3>' (numpy arrays)
activations = {}

# attention hooks: factory returning a hook that stores one layer's to_qkv output
def implant_hook_at(layer_number):
  storage_key = f'at_{layer_number}_qkv'
  def hook_at(module,inputs,output):
    activations[storage_key] = output.detach()
  return hook_at

# MLP hooks: factory returning a hook that re-runs the four feedforward sub-modules
# one at a time on the layer input and stores every intermediate result:
#   0) RMSnorm  1) expansion  2) geglu  3) contraction (projection)
def implant_hook_ff(layer_number):
  def hook_ff(module,inputs,output):

    # calculate: pass the layer input through each sub-module in turn
    stages = []
    current = inputs[0]
    for stage_idx in range(4):
      current = module[stage_idx](current)
      stages.append(current)

    # and store
    for stage_idx,stage in enumerate(stages):
      activations[f'ff_{layer_number}_{stage_idx}'] = stage.detach().numpy()
  return hook_ff


# surgeries: one attention hook and one feedforward hook per layer;
# the handles are kept so the hooks can be removed after the forward pass
handles = []
for i in range(n_layers):
  layer = model.transformer.layers[i]
  handles.append(layer[0].to_qkv.register_forward_hook(implant_hook_at(i)))
  handles.append(layer[1].net.register_forward_hook(implant_hook_ff(i)))

In [None]:
# run a forward pass on 10 epochs; the hooks fill `activations` as a side effect
n_epochs = 10
X = torch.stack([dataset[i]['data'] for i in range(n_epochs)])

# inference only: no_grad skips building an autograd graph (the feedforward hooks also
# re-run sub-modules, which would otherwise be recorded too).
# try/finally removes the hooks even if the forward pass raises, so a failed run does not
# leave them attached to the model.
try:
  with torch.no_grad():
    Y = model(X,positions)
finally:
  for h in handles:
    h.remove()

Y.shape

In [None]:
# list every captured activation with its shape
for name,act in activations.items():
  print(f'"{name}" has shape {list(act.shape)}')

# **Part 6: Characterize QKV activations**

In [None]:
# fused Q|K|V activations from one layer, first epoch (tokens x [Q|K|V] dims)
layeri = 6

wide_acts = activations[f'at_{layeri}_qkv']

plt.figure(figsize=(10,3))
plt.imshow(wide_acts[0],aspect='auto',vmin=-1,vmax=1,cmap='plasma')
for boundary in (emb_dims, 2*emb_dims):   # Q|K and K|V boundaries
  plt.axvline(boundary,linestyle='--',color='w')
plt.colorbar(pad=.01)

plt.gca().set(xticks=[],ylabel='Spatiotemporal token indices',
              xlabel='Queries dimensions         |           Keys dimensions              |           Values dimensions',
              title=f'Attention matrices in layer {layeri}')


plt.tight_layout()
plt.show()

In [None]:
# per-layer summary statistics of the Q, K, and V activations
# descriptives: layer x matrix (Q,K,V) x statistic (mean, standard deviation)
descriptives = torch.zeros((n_layers,3,2))

for layeri in range(n_layers):

  # split into separate matrices
  Q,K,V = torch.split(activations[f'at_{layeri}_qkv'],emb_dims,dim=-1)

  for mi,M in enumerate((Q,K,V)):
    descriptives[layeri,mi,0] = M.mean()
    descriptives[layeri,mi,1] = M.std()


descriptives.shape

In [None]:
# per-layer mean (left) and spread (right) of the Q, K, and V activations
# descriptives[...,1] is filled with .std() in the previous cell, so it is labelled
# as a standard deviation, not a variance
stat_names = ['Mean','Standard deviation']

fig,axs = plt.subplots(1,2,figsize=(12,4))

for i,ax in enumerate(axs):
  ax.plot(descriptives[:,0,i],'gs-',markerfacecolor=[.7,.9,.7],label='Q')
  ax.plot(descriptives[:,1,i],'ro-',markerfacecolor=[.9,.7,.7],label='K')
  ax.plot(descriptives[:,2,i],'b^-',markerfacecolor=[.7,.7,.9],label='V')

  ax.set(xlabel='Layer index',ylabel=stat_names[i],title=stat_names[i])
  ax.legend()

plt.tight_layout()
plt.show()

# **Part 7: Raw and softmax QKᵀ scores**

In [None]:
# middle layer, first epoch only: split Q and K into per-head pieces
layeri = n_layers//2

# tokens x [Q|K|V] -> three tokens x emb_dims matrices
Q,K,V = torch.split(activations[f'at_{layeri}_qkv'][0],emb_dims,dim=1)

# tuples of n_heads tensors, each tokens x head_dim
Q_h, K_h = [torch.split(M,head_dim,dim=1) for M in (Q,K)]

print(f'There are {len(Q_h)} heads')
print(f'Each head has size {Q_h[2].shape}')

In [None]:
# query activations of each head (head dims x tokens), one panel per head
_,axs = plt.subplots(2,4,figsize=(12,4))

for i,ax in enumerate(axs.flatten()):
  ax.pcolor(Q_h[i].T,cmap='plasma',vmin=-2,vmax=2)
  # black label slightly offset behind a white one -> drop-shadow for legibility
  for dx,dy,color in ((2,1,'k'), (1,2,'w')):
    ax.text(dx,head_dim-dy,f'Qh{i}',fontsize=12,fontweight='bold',color=color,ha='left',va='top')
  ax.set(xticks=[],yticks=[])

# axis labels only on the bottom-left panel
axs[-1,0].set(ylabel='Head dim',xlabel='Token position')

plt.tight_layout()
plt.show()

In [None]:
# scaled QK^T dot products for every (query head, key head) pair of the chosen layer,
# pooled into "within head" (qi == ki) and "across head" (qi != ki) groups

# Collect the pieces in lists and concatenate once at the end; growing an array with
# np.concatenate inside the loop re-copies everything collected so far (quadratic).
# The leading empty float64 array keeps the float64 result dtype and yields an empty
# array when a group has no pairs.
within_parts = [np.array([])]
across_parts = [np.array([])]

# loop over pairs of heads
for qi in range(n_heads):
  for ki in range(n_heads):

    # QK' dot products (and vectorized)
    dp = Q_h[qi] @ K_h[ki].t() / sqrtD
    dp = dp.numpy().flatten()

    # store in the appropriate group
    if qi==ki:
      within_parts.append(dp)
    else:
      across_parts.append(dp)

withinhead_dp = np.concatenate(within_parts)
acrosshead_dp = np.concatenate(across_parts)

print(f'There are {len(acrosshead_dp):,} values in "across head"')
print(f'      and {len(withinhead_dp):7,} values in "within head".')

In [None]:
## visualizations: raw QK^T scores from the same head vs. from different heads
_,axs = plt.subplots(1,2,figsize=(10,4))

# A) violin plot of the two groups
v = axs[0].violinplot([withinhead_dp,acrosshead_dp])

# recolor: green body = same head, red body = different heads; white min/max/bar lines
for body,color in zip(v['bodies'],([.7,.9,.7],[.9,.7,.7])):
  body.set_facecolor(color)
  body.set_alpha([.9])
for part in ('cbars','cmins','cmaxes'):
  v[part].set_edgecolor('w')

axs[0].axhline(0,linestyle='--',color=[.7,.7,.7],zorder=-3)
axs[0].set(xticks=[1,2],xticklabels=['Same head','Diff heads'],
           ylabel='QK$^\\top$ dot products',title='A) Raw attention scores',xlim=[.5,2.5])


# B) histogram densities of the same two groups
for scores,color,label in ((withinhead_dp,'g','Same head'),(acrosshead_dp,'r','Diff heads')):
  y,x = np.histogram(scores,bins='fd',density=True)
  axs[1].plot(x[:-1],y,color,linewidth=2,label=label)

axs[1].legend()
axs[1].set(xlabel='Dot product value',ylabel='Density',title='B) Distributions')
axs[1].axvline(0,linestyle='--',color=[.7,.7,.7])

plt.suptitle(f'Data from transformer layer {layeri}',fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# histograms of QK^T scores in every layer, pooled over all epochs in the forward pass:
#   layerHists[:,0,:] raw scores across heads
#   layerHists[:,1,:] raw scores within heads
#   layerHists[:,2,:] softmax scores within heads (binned on [0,1])
histedges = np.linspace(-15,15,101)
smedges = np.linspace(0,1,len(histedges))

layerHists = np.zeros((n_layers,3,len(histedges)-1))


# loop over layers
for layeri in range(n_layers):

  # get the activations and reshape to (epochs, heads, tokens, head_dim);
  # sizes come from the tensor itself instead of hard-coded (10, 100)
  Q,K,V = torch.split(activations[f'at_{layeri}_qkv'],emb_dims,dim=-1)
  Qh = rearrange(Q,'b t (h d) -> b h t d',h=n_heads)
  Kh = rearrange(K,'b t (h d) -> b h t d',h=n_heads)

  # collect pieces in lists and concatenate once (avoids quadratic re-copying);
  # the leading empty float64 array keeps the float64 dtype of the pooled arrays
  within_parts   = [np.array([])]
  across_parts   = [np.array([])]
  withinsm_parts = [np.array([])]

  # loop over pairs of heads
  for qi in range(n_heads):
    for ki in range(n_heads):

      # dot products per head-pair
      dp = (Qh[:,qi,:,:] @ Kh[:,ki,:,:].transpose(-2,-1)) / sqrtD

      # store in the appropriate group
      if qi==ki:
        within_parts.append(dp.flatten().numpy())
        withinsm_parts.append(torch.softmax(dp,dim=-1).flatten().numpy())
      else:
        across_parts.append(dp.flatten().numpy())

  withinhead_dp = np.concatenate(within_parts)
  acrosshead_dp = np.concatenate(across_parts)
  withinhead_sm = np.concatenate(withinsm_parts)

  # distributions
  layerHists[layeri,0,:] = np.histogram(acrosshead_dp,bins=histedges,density=True)[0]
  layerHists[layeri,1,:] = np.histogram(withinhead_dp,bins=histedges,density=True)[0]
  layerHists[layeri,2,:] = np.histogram(withinhead_sm,bins=smedges,density=True)[0]

In [None]:
# laminar heatmaps (layer x score value) of the three QK^T score distributions
panels = [((histedges[0],histedges[-1],0,n_layers),'A) Across heads (raw)'),
          ((histedges[0],histedges[-1],0,n_layers),'B) Within heads (raw)'),
          ((0,1,0,n_layers),                       'C) Within heads (softmax)')]

_,axs = plt.subplots(1,3,figsize=(12,4))

for i,(ax,(extent,title)) in enumerate(zip(axs,panels)):
  ax.imshow(layerHists[:,i,:],origin='lower',aspect='auto',cmap='magma',
            extent=list(extent),vmin=0,vmax=.15)
  if i < 2:   # dashed zero line only on the raw-score panels
    ax.axvline(0,linestyle='--',color='k',linewidth=2)
  ax.set(xlabel='$\\mathbf{QK^\\top}$ activation value',ylabel='Transformer layer',title=title)

plt.tight_layout()
plt.show()

# **Part 8: Attention head entropy and sparseness**

In [None]:
# entropy of each head's softmax attention-score distribution, per layer
entropies = np.zeros((n_layers,n_heads))
kde_grid = np.linspace(0,1,100)   # evaluation points (softmax scores lie in [0,1])


# loop over layers
for layeri in range(n_layers):

  # get the activations and reshape to (epochs, heads, tokens, head_dim);
  # sizes come from the tensor itself instead of hard-coded (10, 100)
  Q,K,V = torch.split(activations[f'at_{layeri}_qkv'],emb_dims,dim=-1)
  Qh = rearrange(Q,'b t (h d) -> b h t d',h=n_heads)
  Kh = rearrange(K,'b t (h d) -> b h t d',h=n_heads)

  # loop over heads (within-head scores only)
  for headi in range(n_heads):

    # softmax attention scores
    dp = (Qh[:,headi,:,:] @ Kh[:,headi,:,:].transpose(-2,-1)) / sqrtD
    sm = torch.softmax(dp,dim=-1)

    # entropy (bits) of a KDE-smoothed distribution of the softmax scores
    kde_y = gaussian_kde(sm.flatten())(kde_grid)   # KDE values on the grid
    kde_y /= kde_y.sum()                           # normalize to prob dist
    entropies[layeri,headi] = -np.sum(kde_y * np.log2(kde_y+1e-10))


In [None]:
# every head's entropy (dots, colored by layer depth) plus the per-layer mean (line)
plt.figure(figsize=(10,3.5))

for layeri,head_entropies in enumerate(entropies):
  layer_color = plt.cm.plasma(layeri/n_layers)
  plt.plot(np.ones(n_heads)*layeri,head_entropies,'wo',markersize=8,alpha=.7,markerfacecolor=layer_color)

plt.plot(entropies.mean(axis=1),'w',linewidth=2,zorder=-100)

plt.gca().set(xlabel='Layer',ylabel='Entropy')
plt.show()

In [None]:
# histogram of the entropies of all (layer, head) combinations
all_entropies = entropies.ravel()

plt.figure(figsize=(8,3))
plt.hist(all_entropies,bins=60,color=[.7,.7,.9],edgecolor='w')
plt.gca().set(xlabel='Within-head entropy',ylabel='Count')
plt.show()

# **Part 9: Feedforward activation characteristics**

In [None]:
# names of the four feedforward stages stored by the hooks (ff_<layer>_0 ... ff_<layer>_3)
ff_labels = ['0) Normalize','1) Expansion','2) GEGLU','3) Projection']

whichlayer = 10
activation_bins = np.linspace(-2,2,201)

# overlay the activation density of each stage for one layer
plt.figure(figsize=(10,4))
for i,label in enumerate(ff_labels):
  y,x = np.histogram(activations[f'ff_{whichlayer}_{i}'],bins=activation_bins,density=True)
  plt.plot(x[:-1],y,label=label,linewidth=2)

plt.axvline(0,linestyle='--',color='w',linewidth=.4,zorder=-10)
plt.gca().set(xlabel='Activation value',ylabel='Density',xlim=x[[0,-1]],
              title=f'Feedforward activations from layer {whichlayer}')

plt.legend()
plt.show()

In [None]:
# laminar distributions (layer x activation value) of each feedforward stage
histedges = np.linspace(-1.2,1.2,101)

# ff_dists: layer x histogram bin x stage (normalize, expansion, GEGLU, projection)
ff_dists = np.zeros((n_layers,len(histedges)-1,4))

for layeri in range(n_layers):

  # get histograms
  for i in range(4):
    ff_dists[layeri,:,i] = np.histogram(activations[f'ff_{layeri}_{i}'],bins=histedges,density=True)[0]


fig,axs = plt.subplots(2,2,figsize=(12,6))
for i,ax in enumerate(axs.flatten()):
  # origin='lower' draws layer 0 (row 0) at the bottom so the row order agrees with the
  # upward 0 -> n_layers y extent (same convention as the other laminar plots);
  # with the default origin='upper' the layer axis would be inverted
  ax.imshow(ff_dists[:,:,i],origin='lower',aspect='auto',
            extent=[histedges[0],histedges[-1],0,n_layers],cmap='magma',vmin=0,vmax=3)
  ax.set(xlabel='Activation value',ylabel='Layer',title=f'{ff_labels[i]}')

plt.tight_layout()
plt.show()