## Module 4: Explaining Models (Dictionary Learning)

In this demo, we'll use dictionary learning to analyze how the final hidden
layer of a GPTNeo model organizes articles from the fineweb-edu dataset. This
hidden layer is 768-dimensional, but analyzing individual neurons is not an
efficient way to work. We will find that the learned dictionary atoms
associated with this layer's activations are much more interesting to look at.

The libraries below give access to datasets and models hosted on Hugging Face. They are already
included in the iisa312 environment, defined in this [yaml file](https://github.com/krisrs1128/talks/blob/master/2024/20241230/examples/environment-iisa312.yaml).

In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, GPTNeoModel
import torch
import numpy as np
np.random.seed(20241230)

The command below defines a data loader for the
[fineweb-edu](https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu)
dataset. This is a 7.5TB dataset, so we'll work with a streaming version, which
lets us read a few articles at a time. We'll be looking at a tiny fraction of
the original data, but it will be enough to see some interesting structure.

In [2]:
fw = load_dataset("HuggingFaceFW/fineweb-edu", name="CC-MAIN-2024-10", split="train", streaming=True)

Let's save 2500 articles from which to extract activations. You can see the first
200 characters of the raw text from a few articles below. They are all somewhat
academic in style, but they range quite dramatically in the topics they discuss.

In [None]:
n_stream = 2500
texts = []
for x in fw:
    texts.append(x["text"])
    # stop once we have exactly n_stream articles
    if len(texts) >= n_stream:
        break

[f"{s[:200]}..." for s in texts[:10]]
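
The same "take the first n items from a stream" pattern can also be written with `itertools.islice`, which stops after exactly `n` items without a manual counter. The sketch below uses a hypothetical generator, `fake_stream`, as a stand-in for the streaming dataset so that it runs without any download. The helper name `take_texts` is also just an illustration; with the real data, `fw` would take the place of `fake_stream()`.

```python
from itertools import islice

def fake_stream():
    """Hypothetical stand-in for a streaming dataset: yields dicts with a "text" field."""
    i = 0
    while True:
        yield {"text": f"article {i}"}
        i += 1

def take_texts(stream, n):
    """Collect the "text" field of the first n records, reading no further."""
    return [x["text"] for x in islice(stream, n)]

texts_demo = take_texts(fake_stream(), 5)
print(len(texts_demo))
```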

The block below extracts embeddings from the final hidden state
(`.hidden_states[-1]`) in a GPTNeo model. Notice that we're averaging the hidden
state over all tokens in the text, so each article is summarized by a single
768-dimensional vector. In theory, we could analyze activations within smaller
stretches of text, but we are aiming more for simplicity than completeness.

In [4]:
def extract_embeddings(text, model, tokenizer):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=False)
    # fill in this function
    pass

# Load pre-trained model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125m")
model = GPTNeoModel.from_pretrained("EleutherAI/gpt-neo-125m")
model = model.eval()
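
One way the stub above might be filled in is sketched below. It assumes the Hugging Face convention that passing `output_hidden_states=True` makes the model return a `hidden_states` tuple, and it returns the token-averaged final hidden state as a 1-D tensor. The name `extract_embeddings_sketch` is hypothetical and is used only to keep it distinct from the exercise function.

```python
import torch

def extract_embeddings_sketch(text, model, tokenizer):
    # Tokenize one article; truncation keeps it within the model's context length
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=False)
    with torch.no_grad():  # inference only, so no gradients are needed
        outputs = model(**inputs, output_hidden_states=True)
    h = outputs.hidden_states[-1]      # shape (1, n_tokens, hidden_size)
    return h.mean(dim=1).squeeze(0)    # average over tokens -> (hidden_size,)
```

With the `model` and `tokenizer` loaded above, `extract_embeddings_sketch(texts[0], model, tokenizer)` should return a tensor of length 768, one entry per hidden unit.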

Applying `extract_embeddings` to all the articles we downloaded above is
relatively fast on a machine with a GPU, but it would be quite slow on the
laptops we're using for these demos. Instead, the block below loads embeddings
that I extracted in advance.

In [5]:
import pickle

with open("embeddings.pkl", "rb") as f:
    embeddings = pickle.load(f)

The functions below define the dictionary learning algorithm. We will alternate
`alpha_update` with `dictionary_update` to learn the sparse codes $\alpha$ and
the dictionary $\Phi$, respectively (inside the functions, the dictionary is
called `D`). For `alpha_update`, we are using a standard implementation of the
fast iterative soft thresholding algorithm (this is where the seemingly
arbitrary formulas like $1 + \sqrt{1 + 4t^{2}}$ are coming from). Notice that we
are constraining $\alpha \succeq 0$ using the `clamp` argument of `threshold`.
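
For reference, the updates implemented below can be written out explicitly. This is a reading of the code rather than something stated elsewhere in these notes: with the dictionary held fixed, `alpha_update` appears to target a nonnegative lasso problem, where $\|\alpha\|_{1}$ is the sum of the absolute values of all entries,

$$
\min_{\alpha \succeq 0} \; \frac{1}{2}\left\|X - \Phi\alpha\right\|_{F}^{2} + \lambda\left\|\alpha\right\|_{1}.
$$

With $L = \|\Phi\|_{2}^{2}$ (the squared spectral norm), $t_{1} = 1$, and $\tilde{\alpha}_{1} = \alpha_{0} = 0$, each inner iteration computes

$$
\begin{aligned}
\alpha_{k} &= \max\left(\tilde{\alpha}_{k} + \tfrac{1}{L}\Phi^{\top}\left(X - \Phi\tilde{\alpha}_{k}\right) - \tfrac{\lambda}{L},\; 0\right), \\
t_{k+1} &= \frac{1 + \sqrt{1 + 4t_{k}^{2}}}{2}, \\
\tilde{\alpha}_{k+1} &= \alpha_{k} + \frac{t_{k} - 1}{t_{k+1}}\left(\alpha_{k} - \alpha_{k-1}\right),
\end{aligned}
$$

where the maximum is taken entrywise. With the codes held fixed, `dictionary_update` sets $\Phi \leftarrow X\alpha^{+}$, where $\alpha^{+}$ is the pseudoinverse; this is a least-squares solution of $\min_{\Phi}\|X - \Phi\alpha\|_{F}^{2}$.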

In [6]:
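def threshold(x, lambd, clamp=True):
    # Soft thresholding: shrink each entry toward zero by lambd
    x_ = np.sign(x) * np.maximum(np.abs(x) - lambd, 0)
    if clamp:
        # Enforce nonnegativity by zeroing any remaining negative entries
        x_[x_ < 0] = 0
    return x_

def alpha_update(X, D, alpha, lambd, max_iter=20):
    # Step size is 1 / L, where L is the squared spectral norm of the dictionary
    L = np.linalg.norm(D, ord=2) ** 2
    # Codes start from zero; `alpha` is only used for its shape
    a = np.zeros_like(alpha)
    a_ = np.copy(a)

    t = 1
    for _ in range(max_iter):
        a_prev = np.copy(a)
        # Proximal gradient step taken from the extrapolated point a_
        a = threshold(a_ + (1 / L) * D.T @ (X - D @ a_), lambd / L)
        # Momentum update
        t_new = (1 + np.sqrt(1 + 4 * t ** 2)) / 2
        a_ = a + ((t - 1) / t_new) * (a - a_prev)
        t = t_new

    return a

def dictionary_update(X, a):
    # Least-squares fit of the dictionary given the codes
    return X @ np.linalg.pinv(a)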
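
To see what the thresholding step does to individual values, here is a small numerical illustration. It re-implements the two-sided soft threshold under the hypothetical name `soft_threshold` and mimics `clamp=True` with an entrywise maximum; the input values are made up.

```python
import numpy as np

def soft_threshold(x, lambd):
    # Shrink magnitudes by lambd, zeroing anything smaller than lambd
    return np.sign(x) * np.maximum(np.abs(x) - lambd, 0)

x = np.array([-2.0, -0.5, 0.3, 1.5])
two_sided = soft_threshold(x, 1.0)      # keeps the sign of large entries
one_sided = np.maximum(two_sided, 0)    # additionally drops negative entries

print(two_sided)
print(one_sided)
```

Only the last entry exceeds the threshold on the positive side, so it is the only one that survives once nonnegativity is imposed.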

Let's run 25 iterations of dictionary learning with $K = 250$.

In [None]:
from tqdm import tqdm

K = 250
X = torch.stack(embeddings).numpy().T
D, N = X.shape
Phi = np.random.randn(D, K)  # Initial dictionary
alpha = np.random.rand(K, N)  # Initial sparse codes
l = 0.5  # Regularization parameter

### define the dictionary learning optimization
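
One possible shape for this optimization is sketched below on small synthetic data, so that it runs in well under a second. It is a self-contained illustration, not the reference solution: `nonneg_fista` and `fit_dictionary` are hypothetical names, the synthetic `X_demo` is made up, and the compact coding step is meant to mirror `alpha_update` with `clamp=True`.

```python
import numpy as np

def nonneg_fista(X, Phi, lambd, max_iter=20):
    """Nonnegative sparse codes for a fixed dictionary (cold start at zero)."""
    L = np.linalg.norm(Phi, ord=2) ** 2          # squared spectral norm
    a = np.zeros((Phi.shape[1], X.shape[1]))
    a_, t = a.copy(), 1.0
    for _ in range(max_iter):
        a_prev = a
        step = a_ + Phi.T @ (X - Phi @ a_) / L   # gradient step at the extrapolated point
        a = np.maximum(step - lambd / L, 0)      # nonnegative soft threshold
        t_new = (1 + np.sqrt(1 + 4 * t ** 2)) / 2
        a_ = a + ((t - 1) / t_new) * (a - a_prev)
        t = t_new
    return a

def fit_dictionary(X, K, lambd, n_iter=25, seed=0):
    """Alternate sparse coding and least-squares dictionary updates."""
    rng = np.random.default_rng(seed)
    Phi = rng.standard_normal((X.shape[0], K))   # random initial dictionary
    errors = []
    for _ in range(n_iter):
        alpha = nonneg_fista(X, Phi, lambd)      # codes given the dictionary
        Phi = X @ np.linalg.pinv(alpha)          # dictionary given the codes
        errors.append(np.linalg.norm(X - Phi @ alpha))
    return Phi, alpha, errors

# Synthetic data: 100 samples in 20 dimensions built from 5 nonnegative sparse codes
rng = np.random.default_rng(1)
Phi_true = rng.standard_normal((20, 5))
alpha_true = rng.random((5, 100)) * (rng.random((5, 100)) < 0.3)
X_demo = Phi_true @ alpha_true + 0.01 * rng.standard_normal((20, 100))

Phi_hat, alpha_hat, errors = fit_dictionary(X_demo, K=5, lambd=0.1)
print(errors[0], errors[-1])
```

In the notebook itself, the same loop would presumably call `alpha_update(X, Phi, alpha, l)` followed by `dictionary_update(X, alpha)` 25 times, optionally wrapped in `tqdm` to show progress.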

Finally, we can look at the articles that have especially high activations on a
given dictionary atom, that is, the largest entries in the corresponding row of
$\alpha$. For example, it seems the first dictionary atom $\phi_{1}$ is mainly
related to languages.

In [None]:
top_ix = np.argsort(alpha[0, :])[-10:][::-1]
[f"{texts[i][:200]}..." for i in top_ix]
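
To browse other atoms, the same lookup can be wrapped in a small helper. The sketch below is one way to do this; `top_articles` is a hypothetical name, and the toy `alpha_demo` and `texts_demo` values are made up so the example runs on its own.

```python
import numpy as np

def top_articles(alpha, texts, k, n_top=10, n_chars=200):
    """Return previews of the n_top articles with the largest codes on atom k."""
    # argsort is ascending, so take the last n_top indices and reverse them
    top_ix = np.argsort(alpha[k, :])[-n_top:][::-1]
    return [f"{texts[i][:n_chars]}..." for i in top_ix]

# Toy example: 2 atoms, 4 articles
alpha_demo = np.array([[0.1, 0.9, 0.0, 0.5],
                       [0.7, 0.0, 0.2, 0.3]])
texts_demo = ["about birds", "about grammar", "about rivers", "about verbs"]

print(top_articles(alpha_demo, texts_demo, k=0, n_top=2))
```

With the fitted codes, `top_articles(alpha, texts, k=1)` would list the articles that load most heavily on the second atom, and so on.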