# Training an SAE on the outputs and activations of the middle MLP of llama3.2 1B

## Training cost

Generally, the computational cost of an LLM is dominated by the evaluation of
its MLPs (reference: a university seminar given by a former Anthropic
employee).

A simple approximation of the cost of training a model, measured in FLOPs
(floating-point operations), is

$$
  6ND
$$

where $N$ is the number of parameters and $D$ is the number of samples (tokens)
in the training set (cite the Chinchilla scaling laws).

This is because each parameter generally takes part in one floating-point
multiplication and one addition per sample. That gives a cost of $2ND$ for the
forward pass alone. The backward pass typically costs twice as much as the
forward pass, i.e. $4ND$. Adding both gives the result above.
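Here is a minimal sketch of this rule of thumb. It splits the $6ND$ total into its forward ($2ND$) and backward ($4ND$) parts. The function name `training_flops` is hypothetical and does not come from any library.

```python
def training_flops(n_params: float, n_tokens: float) -> dict:
    """Approximate training FLOPs with the 6ND rule of thumb."""
    forward = 2 * n_params * n_tokens  # one multiply + one add per parameter per token
    backward = 2 * forward             # backward pass costs ~twice the forward pass
    return {"forward": forward, "backward": backward, "total": forward + backward}


flops = training_flops(1.0e9, 1.0e9)
print(f"{flops['total']:.1e}")  # 6.0e+18
```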

Let $N_l$ be the number of parameters of llama.
We are going to train an autoencoder on an MLP in the middle of llama, so we
only need to evaluate the first half of the model. We also will not run the
backward pass over llama's parameters, since we do not want to modify them;
we keep them frozen. Hence, the number of FLOPs performed by that half of the
llama model is

$$
  N_l D
$$
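This count follows from one step, assuming the parameters before the middle MLP are roughly half of $N_l$. The symbol $C_\text{llama}$ for this cost is introduced here only as notation. The frozen half runs only the forward pass, at 2 FLOPs per parameter per sample, so

$$
  C_\text{llama} \approx 2 \cdot \frac{N_l}{2} \cdot D = N_l D
$$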

As for the SAE, counting only the cost of applying its matrices, we have

$$
  6 (2d_\text{in}d_\text{sae})D
$$
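The factor $2 d_\text{in} d_\text{sae}$ is the SAE's parameter count if we count only the encoder and decoder matrices. Biases and thresholds add $O(d_\text{in} + d_\text{sae})$ parameters, which is negligible. Unlike llama's parameters, these are trained, so the full $6ND$ rule applies:

$$
  N_\text{sae} = \underbrace{d_\text{in} d_\text{sae}}_{W_\text{enc}} + \underbrace{d_\text{sae} d_\text{in}}_{W_\text{dec}} = 2 d_\text{in} d_\text{sae},
  \qquad 6 N_\text{sae} D = 6 (2 d_\text{in} d_\text{sae}) D
$$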

In the case of Gemma Scope, the smallest of all the trained SAEs on MLP outputs
used 4 billion activation vectors, with $d_\text{sae} = 2048 \cdot 8 = 16384$.
In gemma 2 2B the residual stream, and therefore the output of the MLP layers,
has dimension 2304; in llama3.2 1B it has dimension 2048.

We want to find hyperparameters for training SAEs. To do so:
- We take a reasonable first approximation by adapting the hyperparameters of
  the smallest autoencoders trained on the MLP outputs of gemma 2 2B.
- We fit a power law based on 2 training runs of smaller SAEs, using the same
  learning rate.
- We fit a power law for the learning rate using the optimal hyperparameters
  estimated in the previous step.
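With only two runs, a power law $y = a x^b$ is fully determined by the two points, because it is a straight line in log-log space. Below is a minimal sketch. The function names and the numbers in the example are hypothetical placeholders, not results from this notebook.

```python
import math


def fit_power_law(x1: float, y1: float, x2: float, y2: float) -> tuple[float, float]:
    """Return (a, b) such that y = a * x**b passes through both points."""
    b = math.log(y2 / y1) / math.log(x2 / x1)  # slope in log-log space
    a = y1 / x1**b                              # intercept
    return a, b


def power_law(a: float, b: float, x: float) -> float:
    return a * x**b


# Hypothetical example: two small runs, extrapolated to a 10x larger setting
a, b = fit_power_law(1e6, 0.5, 1e7, 0.25)
print(round(b, 4))                     # -0.301
print(round(power_law(a, b, 1e8), 4))  # 0.125
```

The same two-point fit would apply to the learning-rate step, with the optimal hyperparameters from the previous fit as the $y$ values.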

If we ignore the positional relationships between tokens and assume a uniform
distribution, the entropy per token is

$$
  \log_2 (\text{vocabulary size})
$$
The vocabulary of gemma2 2B has 256000 tokens and that of llama has 128000.
Keeping the total number of bits fixed, the equivalent number of tokens is
$4 \times 10^9 \cdot \log_2 256000 / \log_2 128000 \approx 4.2 \times 10^9$.
Since our model is half the size of gemma2 2B, a reasonable first amount of
data to train our largest SAE is $2.1 \times 10^9$ vectors.
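Here is a quick check of the conversion above. Like the text, it assumes uniformly distributed tokens and vocabulary sizes of exactly 256000 and 128000.

```python
import math

gemma_vocab, llama_vocab = 256_000, 128_000
gemma_tokens = 4e9  # activation vectors used for the smallest Gemma Scope MLP-output SAEs

# Same total information (bits) => token counts scale with the ratio of bits per token
bits_gemma = math.log2(gemma_vocab)  # ~17.97 bits/token
bits_llama = math.log2(llama_vocab)  # ~16.97 bits/token
equivalent_tokens = gemma_tokens * bits_gemma / bits_llama

# Our model is roughly half the size of gemma2 2B, so halve the data
first_guess = equivalent_tokens / 2
print(f"{equivalent_tokens:.2e} {first_guess:.2e}")  # 4.24e+09 2.12e+09
```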

For llama3.2 1B, this gives

$$
  N_l D = (1.2 \times 10^9)(2.1 \times 10^9) \approx 2.5 \times 10^{18}
$$
An RTX 4090 can perform at most $165 \times 10^{12}$ operations per second on
16-bit tensors with a 32-bit accumulator (reference to technical report
v1.0.1). Hence we estimate a lower bound of about 4.2 hours of training,
considering only the computation in the llama model.
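The same arithmetic is written out below. It assumes the GPU sustains its peak dense 16-bit throughput, which real training will not reach.

```python
n_llama = 1.2e9      # parameters of llama3.2 1B (rounded)
n_samples = 2.1e9    # activation vectors (tokens)
peak_flops = 165e12  # RTX 4090 peak, 16-bit tensors with 32-bit accumulate

llama_flops = n_llama * n_samples  # forward pass over half the model: N_l * D
hours = llama_flops / peak_flops / 3600
print(f"{llama_flops:.2e} FLOPs, {hours:.1f} h")  # 2.52e+18 FLOPs, 4.2 h
```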

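Now, to estimate the RTX 4090-hours for the autoencoder, in the case of
training it on the output of the middle MLP, we would have

$$
    6(2 d_\text{in} d_\text{sae})D = 6 (2 \cdot 2048 \cdot 8 \cdot 2048)(2.1 \times 10^9) \approx 8.5 \times 10^{17}
$$
FLOPs, i.e. about 1.4 hours at the same peak throughput.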
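The code below reproduces this estimate. It assumes 16-bit matmuls at peak throughput. The training cell in this notebook actually runs in float32, for which the RTX 4090's peak is lower, so treat the result as a lower bound. The SAE adds roughly a third of llama's cost.

```python
d_in = 2048
d_sae = 8 * d_in     # 16384 latents
n_samples = 2.1e9
peak_flops = 165e12  # RTX 4090, 16-bit tensors with 32-bit accumulate

sae_params = 2 * d_in * d_sae           # encoder + decoder matrices only
sae_flops = 6 * sae_params * n_samples  # trained, so the full 6ND rule applies
sae_hours = sae_flops / peak_flops / 3600
print(f"{sae_flops:.2e} FLOPs, {sae_hours:.2f} h")  # 8.46e+17 FLOPs, 1.42 h
```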

# The model

In [1]:
import torch
from torch import nn
import jaxtyping
import dataclasses
from tqdm import tqdm
from torch.nn import functional as F
from torch.linalg import vector_norm
import math
import datasets
import numpy as np
from torch.utils.data import Dataset, DataLoader

In [2]:
dtype = torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
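class Step(torch.autograd.Function):
    """Heaviside step H(x - threshold) used by the JumpReLU-style SAE.

    The forward pass is exact. The backward pass passes no gradient to x
    through the step and uses a straight-through pseudo-gradient for the
    threshold, with a rectangle kernel of width `bandwidth`.
    """
    @staticmethod
    def forward(ctx, x, threshold):
        ctx.save_for_backward(x, threshold)
        return (x > threshold).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        bandwidth = 0.001
        x, threshold = ctx.saved_tensors

        # -1/bandwidth inside a window of width `bandwidth` centred on the
        # threshold, 0 elsewhere
        grad_threshold = torch.where(
            abs(F.relu(x) - threshold) < bandwidth/2,
            -1.0/bandwidth, 0)

        # grad_threshold has shape (batch, d_sae); autograd sums it over the
        # broadcast batch dimension to match threshold's shape (d_sae,)
        return torch.zeros_like(x), grad_threshold * grad_output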
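Why is the rectangle-window pseudo-gradient in `Step.backward` a sensible signal for the threshold? Averaged over samples, $-\frac{1}{\varepsilon}\mathbb{1}[|x-\theta|<\varepsilon/2]$ is a kernel density estimate of $-p(\theta)$. That is the true derivative of $\mathbb{E}[H(x-\theta)] = P(x > \theta)$ with respect to $\theta$. The sketch below is pure Python (no torch). It assumes pre-activations uniform on $[0, 1)$, where $p(\theta) = 1$, and drops the ReLU since all samples are non-negative. The helper name is hypothetical.

```python
import random


def threshold_pseudo_grad(x: float, theta: float, bandwidth: float) -> float:
    """Scalar rectangle-kernel pseudo-derivative of H(x - theta) w.r.t. theta."""
    return -1.0 / bandwidth if abs(x - theta) < bandwidth / 2 else 0.0


random.seed(0)
samples = [random.random() for _ in range(200_000)]  # uniform on [0, 1): density 1
theta, bandwidth = 0.3, 0.01
estimate = sum(threshold_pseudo_grad(x, theta, bandwidth) for x in samples) / len(samples)
print(estimate)  # close to -1.0, i.e. -p(theta)
```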

In [4]:
class Sae(nn.Module):
    def __init__(self, d_in, d_sae, use_pre_enc_bias=True, **kwargs):
        super().__init__(**kwargs)
        self.enc = nn.Linear(d_in, d_sae, dtype=dtype)
        self.dec = nn.Linear(d_sae, d_in, dtype=dtype)
        with torch.no_grad():
            # dec.weight has shape (d_in, d_sae); each of its d_sae columns is
            # a dictionary vector. Note that dim=1 normalizes the rows instead.
            self.dec.weight /= vector_norm(self.dec.weight, dim=1, keepdim=True)
            self.enc.weight.copy_(self.dec.weight.clone().t())
            self.enc.bias.copy_(torch.zeros_like(self.enc.bias))
            self.dec.bias.copy_(torch.zeros_like(self.dec.bias))
        self.log_threshold = nn.Parameter(
            torch.log(torch.full((d_sae,), 0.001, dtype=dtype)))
        self.use_pre_enc_bias = use_pre_enc_bias
        def project_out_parallel_grad(dim, tensor):
            @torch.no_grad
            def hook(grad_in):
                # norm along dim=dim of the tensor is assumed to be 1 as we
                # are going to normalize it after every grad update
                dot = (tensor * grad_in).sum(dim=dim, keepdim=True)
                return grad_in - dot * tensor
            return hook

        self.dec.weight.register_hook(
            project_out_parallel_grad(1, self.dec.weight))
                

    def forward(self,
        x,
        return_mask=False,
        return_l0=True,
        return_reconstruction_loss=True,
    ):
        "We compute this much here so that compile() can do its magic"
        # as per train_gpt2.py in karpathy's llm.c repo, there are performance
        # reasons not to return tensors that are not needed
        d = {}
        original_input = x
        if self.use_pre_enc_bias:
            x = x - self.dec.bias
        
        x = self.enc(x)
        threshold = torch.exp(self.log_threshold)
        s = Step.apply(x, threshold)
        if return_mask:
            d['mask'] = s
        if return_l0:
            d['l0'] = s.float().mean(0).sum(-1)
        if not return_reconstruction_loss:
            return d
        x = x*s
        x = self.dec(x)

        d['reconstruction'] = ((x - original_input).pow(2)).mean(0).sum()

        return d
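The gradient hook registered on `dec.weight` works on each slice along `dim`. It removes the component of the gradient parallel to that unit-norm slice, so a small update stays approximately on the unit sphere before the explicit renormalization. Below is a pure-Python sketch for a single vector. Like the hook's comment, it assumes `w` already has unit norm. The helper name is hypothetical.

```python
def project_out_parallel(w: list[float], g: list[float]) -> list[float]:
    """Return g minus its component along the unit vector w: g - (w . g) w."""
    dot = sum(wi * gi for wi, gi in zip(w, g))
    return [gi - dot * wi for wi, gi in zip(w, g)]


w = [0.6, 0.8]  # unit norm
g = [1.0, 1.0]
g_perp = project_out_parallel(w, g)
print(g_perp)  # [0.16, -0.12] up to float rounding
print(sum(wi * gi for wi, gi in zip(w, g_perp)))  # ~0.0: orthogonal to w
```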

In [5]:
def cosine_schedule_with_warmup(
    current_step: int,
    warmup_steps: int,
    total_steps: int
    ):
    if current_step < warmup_steps:
        lr =  (1 + current_step) / warmup_steps
        return lr
    progress = (current_step - warmup_steps) / (total_steps - warmup_steps)
    lr =  0.5 * (1 + math.cos(math.pi * progress))
    return lr

In [6]:
def sparsity_schedule(step, warmup_steps, max_sparsity_coeff):
    if step >= warmup_steps:
        return max_sparsity_coeff
    return max_sparsity_coeff*((step+1) / warmup_steps)
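As a consistency check of the linear sparsity warm-up, consider the `sparsity_coefficient` values printed during training. They are logged every 5000 steps, with a warm-up of 128000 steps and a maximum coefficient of 0.0004, and follow directly from this ramp. The sketch below restates the same formula under the hypothetical name `linear_sparsity_ramp`, so it does not shadow the notebook's function.

```python
import math


def linear_sparsity_ramp(step: int, warmup_steps: int, max_coeff: float) -> float:
    # same formula as sparsity_schedule above
    if step >= warmup_steps:
        return max_coeff
    return max_coeff * ((step + 1) / warmup_steps)


# values copied from the training log, keyed by the step at which they were printed
logged = {0: 3.1250000000000003e-09, 5000: 1.5628125000000002e-05, 60000: 0.000187503125}
for step, value in logged.items():
    computed = linear_sparsity_ramp(step, 128_000, 0.0004)
    print(step, computed, math.isclose(computed, value, rel_tol=1e-9))
```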

In [7]:
from datasets import load_dataset

# Load the dataset
ds = load_dataset('mech-interp-uam/llama-mlp8-outputs')
ds.set_format('numpy')

# Check the dataset structure
print(f"Dataset loaded successfully: {len(ds['train'])} examples")
print(f"Features: {ds['train'].features}")
print(f"First example shape: {ds['train'][0]['activations'].shape}")

Dataset loaded successfully: 668501 examples
Features: {'activations': Sequence(feature=Value(dtype='float16', id=None), length=2048, id=None)}
First example shape: (2048,)


In [8]:
B = 1024
n = 100
sample_norms = (
    ds['train']
    .batch(B)
    .shuffle()
    .take(n)
    .map(lambda row: {"norm": np.linalg.norm(row["activations"], axis=1).mean()},
         remove_columns=['activations'])
)
norm = None
for i, sample_norm in enumerate(sample_norms):
    current_norm = sample_norm['norm']
    if i == 0:
        norm = current_norm
        continue
    norm = i/(i+1) * norm + 1/(i+1) * current_norm

print(norm)

3.4179318
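The update `norm = i/(i+1) * norm + 1/(i+1) * current_norm` is the incremental form of the arithmetic mean, $m_i = \frac{i}{i+1} m_{i-1} + \frac{1}{i+1} x_i$. So `norm` ends up as the plain average of the per-batch mean norms. A pure-Python check with made-up values:

```python
values = [3.1, 3.5, 3.3, 3.9, 3.2]  # hypothetical per-batch mean norms

running = None
for i, v in enumerate(values):
    running = v if i == 0 else i / (i + 1) * running + 1 / (i + 1) * v

print(running, sum(values) / len(values))  # both ~3.4
```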


In [9]:
class ActivationsDataset(Dataset):
    """Wraps the HF dataset so that the DataLoader yields activation tensors."""
    def __init__(self, hf_dataset):
        self.data = hf_dataset

    def __getitem__(self, idx):
        # Returned as float16; on modern GPUs the float16 -> float32 conversion
        # is reportedly cheap compared to the matmuls
        return torch.tensor(self.data[idx]['activations'])

    def __len__(self):
        return len(self.data)


In [10]:
batch = 1024
dataset = ActivationsDataset(ds['train'])
dataloader = DataLoader(
    dataset,
    batch_size=batch,
    shuffle=True,
    num_workers=12,
    pin_memory=True,
    prefetch_factor=32,
    persistent_workers=True,
    drop_last=True,
)

In [11]:
len(dataloader)

652
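With `drop_last=True`, the number of batches per epoch is $\lfloor 668501 / 1024 \rfloor = 652$, matching the output above. The sketch below shows what this implies for the planned `total_steps=128000`. Variable names are chosen so they do not shadow the notebook's own.

```python
n_vectors = 668_501    # examples in ds['train']
batch_size = 1024
planned_steps = 128_000

steps_per_epoch = n_vectors // batch_size  # drop_last=True discards the remainder
epochs = planned_steps / steps_per_epoch
vectors_seen = planned_steps * batch_size
print(steps_per_epoch, round(epochs, 1), f"{vectors_seen:.2e}")  # 652 196.3 1.31e+08
```

That is about $1.3 \times 10^8$ vectors processed, far fewer than the $2.1 \times 10^9$ estimated in the cost section. The same 668,501 stored vectors would be revisited about 196 times.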

In [None]:
#torch.set_float32_matmul_precision('high')
steps = 2**16
max_lr = 7e-5
d_in = 2048
d_sae = 2048*8
model = Sae(d_in, d_sae)
model.to(device)
model.compile()
warmup_steps=1000
sparsity_warmup_steps=128000
total_steps=128000 #for now
optimizer = torch.optim.Adam([
    {"params": model.enc.parameters(), "lr":   max_lr, "betas":(0.0, 0.999)},
    {"params": model.dec.parameters(), "lr":   max_lr, "betas":(0.0, 0.999)},
    {"params": model.log_threshold,    "lr": 2*max_lr, "betas":(0.99,0.999)},
])
max_sparsity_coeff = 0.0004

scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: cosine_schedule_with_warmup(step, warmup_steps, total_steps)
)
# train_ds = fake_train_loader(batch, d_in, total_steps)
total_step = 0
should_break = False
for epoch in range(1000):
    for step, x in enumerate(tqdm(dataloader)):
        x = x.to(device, non_blocking=True).to(dtype)
        x /= 3.4  # mean activation norm estimated above (~3.418)
        optimizer.zero_grad()
        d = model(x)
        reconstruction_loss, l0 = d['reconstruction'], d['l0']
        sparsity_coefficient = sparsity_schedule(total_step, sparsity_warmup_steps, max_sparsity_coeff)
        loss = reconstruction_loss + sparsity_coefficient * l0
        # log losses, compute stats, etc
        loss.backward()  # returns None; gradients are accumulated in .grad
        # norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # metrics
        if (total_step % 5000) == 0:
            with torch.no_grad():
                # print metrics
                print(f"reconstruction={reconstruction_loss.item()}")
                print(f"l0={l0.item()}")
                # print(f"norm={norm.item()}")
                print(f"{sparsity_coefficient=}")
        optimizer.step()
        # TODO: sparsity_coefficient scheduler
        # print(scheduler.get_lr())
        scheduler.step()

        # renormalize the decoder after every update (along dim=1, the same
        # convention used in Sae.__init__ and in the gradient hook)
        with torch.no_grad():
            wdecnorm = vector_norm(model.dec.weight, dim=1, keepdim=True)
            model.dec.weight /= wdecnorm
        total_step +=1
        if total_step > total_steps:
            should_break = True
            break
    # print(f"epoch loss: {loss.detach().item()}")
    if should_break:
        break



reconstruction=0.31900954246520996
l0=7344.169921875
sparsity_coefficient=3.1250000000000003e-09


reconstruction=0.042265087366104126
l0=9552.5205078125
sparsity_coefficient=1.5628125000000002e-05


reconstruction=0.039009954780340195
l0=7292.310546875
sparsity_coefficient=3.1253125e-05


reconstruction=0.042998429387807846
l0=4478.564453125
sparsity_coefficient=4.6878125e-05


reconstruction=0.06718900799751282
l0=2394.63671875
sparsity_coefficient=6.250312500000001e-05


reconstruction=0.08861327171325684
l0=1413.4609375
sparsity_coefficient=7.812812500000001e-05


reconstruction=0.09703446924686432
l0=1044.912109375
sparsity_coefficient=9.375312500000001e-05


reconstruction=0.10535077750682831
l0=893.7041015625
sparsity_coefficient=0.000109378125


reconstruction=0.1057296097278595
l0=784.95703125
sparsity_coefficient=0.000125003125


reconstruction=0.10894829034805298
l0=717.8662109375
sparsity_coefficient=0.000140628125


reconstruction=0.11511114239692688
l0=664.91015625
sparsity_coefficient=0.000156253125


reconstruction=0.11739729344844818
l0=618.78125
sparsity_coefficient=0.000171878125


reconstruction=0.11883846670389175
l0=576.2802734375
sparsity_coefficient=0.000187503125



In [None]:
torch.save(model.state_dict(), "/workspace/llama3.2-1b-sae/sae.pth")

In [None]:
sae2 = Sae(d_in, d_sae)
sae2.load_state_dict(model.state_dict())

<All keys matched successfully>

In [None]:
sa