In [1]:
%load_ext autoreload
%autoreload 2

from nnsight import LanguageModel
from dictionary_learning.buffer import ActivationBuffer
from dictionary_learning.training import trainSAE
from transformer_lens import HookedTransformer, utils
import torch
import os
from tqdm import tqdm
from tqdm.notebook import tqdm as tqdm_notebook

In [2]:

model = LanguageModel(
    # 'EleutherAI/pythia-70m-deduped', # this can be any Huggingface model
    'EleutherAI/pythia-2.8b-deduped',
    device_map = 'cuda:0'
)
submodule = model.gpt_neox.layers[1].mlp # layer 1 MLP
activation_dim = 512 # output dimension of the MLP
dictionary_size = 16 * activation_dim

In [None]:
tl_model = HookedTransformer

In [3]:
import torch
def display_memory():
    total = torch.cuda.get_device_properties(0).total_memory
    r = torch.cuda.memory_reserved(0)
    a = torch.cuda.memory_allocated(0)
    print(f"{a*1e-9} allocated, {r*1e-9} reserved, {total*1e-9} total")

In [4]:

# data much be an iterator that outputs strings
# data = iter([
#     'This is some example data',
#     'In real life, for training a dictionary',
#     'you would need much more data than this'
# ])

from datasets import load_dataset
import torch

# Load the dataset
# train_dataset = load_dataset('wikitext', 'wikitext-103-v1', split='train[:1000000]')
train_dataset = load_dataset('Skylion007/openwebtext', split='train[:100]')
def yield_sentences(data_split):
    for example in data_split:
        text = example['text']
        sentences = text.split('\n')
        for sentence in sentences:
            if sentence:  # skip empty lines
                yield sentence

# Creating an iterator for training sentences
train_sentences = yield_sentences(train_dataset)

# for i in range(10):
#     print(next(train_sentences))

buffer = ActivationBuffer(
    train_sentences,
    model,
    submodule,
    out_feats=activation_dim, # output dimension of the model component
    n_ctxs=3e3,
    in_batch_size=128, # batch size for the model
    out_batch_size=128*16, # batch size for the buffer
) # buffer will return batches of tensors of dimension = submodule's output dimension


In [5]:
display_memory()

0.0 allocated, 0.0 reserved, 84.986691584 total


In [6]:
from tqdm.notebook import tqdm
# # train the sparse autoencoder (SAE)
ae = trainSAE(
    buffer,
    activation_dim,
    dictionary_size,
    lr=3e-4,
    sparsity_penalty=1e-3,
    device='cuda:0',
    tqdm_style=tqdm
)

0it [00:00, ?it/s]

You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


refreshing buffer...
buffer refreshed...
acts.shape=torch.Size([256, 512])
step 0 memory: 0.47353139200000005 allocated, 0.557842432 reserved, 84.986691584 total
step 0 MSE loss: 0.2877471446990967, sparsity loss: 479.6265563964844
acts.shape=torch.Size([256, 512])
acts.shape=torch.Size([256, 512])
acts.shape=torch.Size([256, 512])
acts.shape=torch.Size([256, 512])
acts.shape=torch.Size([256, 512])
acts.shape=torch.Size([256, 512])
acts.shape=torch.Size([256, 512])
acts.shape=torch.Size([256, 512])
acts.shape=torch.Size([256, 512])
acts.shape=torch.Size([256, 512])
acts.shape=torch.Size([256, 512])
acts.shape=torch.Size([256, 512])
acts.shape=torch.Size([256, 512])
acts.shape=torch.Size([256, 512])
acts.shape=torch.Size([256, 512])
acts.shape=torch.Size([256, 512])
refreshing buffer...
buffer refreshed...
acts.shape=torch.Size([256, 512])
acts.shape=torch.Size([256, 512])
acts.shape=torch.Size([256, 512])
acts.shape=torch.Size([256, 512])
acts.shape=torch.Size([256, 512])
acts.shape=to

In [None]:
from dictionary_learning.dictionary import AutoEncoder
from dictionary_learning.buffer import ActivationBuffer
from dictionary_learning.training import ConstrainedAdam, sae_loss, resample_neurons

def trainSAE(
        activations, # a generator that outputs batches of activations
        activation_dim, # dimension of the activations
        dictionary_size, # size of the dictionary
        lr,
        sparsity_penalty,
        entropy=False,
        steps=None, # if None, train until activations are exhausted
        warmup_steps=1000, # linearly increase the learning rate for this many steps
        resample_steps=25000, # how often to resample dead neurons
        save_steps=None, # how often to save checkpoints
        save_dir=None, # directory for saving checkpoints
        log_steps=1000, # how often to print statistics
        device='cpu',
        tqdm_style=tqdm
        ):
    """
    Train and return a sparse autoencoder
    """
    ae = AutoEncoder(activation_dim, dictionary_size).to(device)
    alives = torch.zeros(dictionary_size).bool().to(device) # which neurons are not dead?

    # set up optimizer and scheduler
    optimizer = ConstrainedAdam(ae.parameters(), ae.decoder.parameters(), lr=lr)
    def warmup_fn(step):
        if step < warmup_steps:
            return step / warmup_steps
        else:
            return 1.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_fn)

    for step, acts in enumerate(tqdm_style(activations, total=steps)):
        if steps is not None and step >= steps:
            break

        if isinstance(acts, torch.Tensor): # typical casse
            acts = acts.to(device)
        elif isinstance(acts, tuple): # for cases where the autoencoder input and output are different
            acts = tuple(a.to(device) for a in acts)

        # print(f"{acts.shape=}")
        optimizer.zero_grad()
        loss = sae_loss(acts, ae, sparsity_penalty, entropy, separate=False)
        loss.backward()
        optimizer.step()
        scheduler.step()

        # deal with resampling neurons
        if resample_steps is not None:
            with torch.no_grad():
                if isinstance(acts, tuple):
                    in_acts = acts[0]
                else:
                    in_acts = acts
                dict_acts = ae.encode(in_acts)
                alives = torch.logical_or(alives, (dict_acts != 0).any(dim=0))
                if step % resample_steps == resample_steps // 2:
                    alives = torch.zeros(dictionary_size).bool().to(device)
                if step % resample_steps == resample_steps - 1:
                    deads = ~alives
                    if deads.sum() > 0:
                        print(f"resampling {deads.sum().item()} dead neurons")
                        resample_neurons(deads, acts, ae, optimizer)

        # logging
        if log_steps is not None and step % log_steps == 0:
            total = torch.cuda.get_device_properties(0).total_memory
            r = torch.cuda.memory_reserved(0)
            a = torch.cuda.memory_allocated(0)
            print(f"step {step} memory: {a*1e-9} allocated, {r*1e-9} reserved, {total*1e-9} total")
            with torch.no_grad():
                mse_loss, sparsity_loss = sae_loss(acts, ae, sparsity_penalty, entropy, separate=True)
                print(f"step {step} MSE loss: {mse_loss}, sparsity loss: {sparsity_loss}")
                # dict_acts = ae.encode(acts)
                # print(f"step {step} % inactive: {(dict_acts == 0).all(dim=0).sum() / dict_acts.shape[-1]}")
                # if isinstance(activations, ActivationBuffer):
                #     tokens = activations.tokenized_batch().input_ids
                #     loss_orig, loss_reconst, loss_zero = reconstruction_loss(tokens, activations.model, activations.submodule, ae)
                #     print(f"step {step} reconstruction loss: {loss_orig}, {loss_reconst}, {loss_zero}")

        # saving
        if save_steps is not None and save_dir is not None and step % save_steps == 0:
            if not os.path.exists(os.path.join(save_dir, "checkpoints")):
                os.mkdir(os.path.join(save_dir, "checkpoints"))
            torch.save(
                ae.state_dict(), 
                os.path.join(save_dir, "checkpoints", f"ae_{step}.pt")
                )

    return ae