# Entrenamiento de un SAE sobre las salidad y activaciones de la MLP intermedia de llama3.2 1B

## Costo de entrenamiento

Generalmente, el costo computacional de un LLM está dominado por la evaluación
de sus MLPs (referencia a el seminario en una universidad de un ex empleado
de anthropic)

Una aprocimación simple del costo de entrenamiento de un modelo es

$$
  6ND 
$$
esto en términos de FLOPs (operaciones de punto flotante).
(citar a chinchilla scaling laws)

donde $N$ es el número de parámetros y $D$ la cantidad de muestras en el
conjunto de entrenamiento.

Esto se debe a que, generalmente, cada parámetro actua en una multiplicaciónl
y en una suma de punto flotante, dandonos un costo de $2ND$ tan solo en el
forward pass. Tipicamente, el costo de el backward pass es el doble del forward
pass, haciendo que su costo sea $4ND$. Sumando tenemos el resultado previamente
mencionado.

Sea $N_l$ el número de parámetros de llama.
Como nosotros vamos a entrenar un autoencoder sobre un MLP a la mitad de llama,
solo necesitamos evaluar esa primera mitad. Además, no correremos el backward
pass sobre los parámetros de llama, pues no buscamos modificarlos, es decir, los
mantendrémos fijos. Por esto, tenemos que el FLOPs realizados por tal mitad del
modelo llama es

$$
  N_l D
$$

En cuanto al SAE, solo considerando el costo de aplicar sus matrices, tenemos

$$
  6 (2d_\text{in}d_\text{sae})D
$$

En el caso de gemmascope, entre todos los SAEs que entrenaron, los más pequeños
entrenados en las salidas de las MLPs, se entrenaron en 4 mil millones de
vectores de activaciones, con la dimencion de los vectores en el stream
recidual (y por lo tanto, de la salida de las capas MLP), es 2048, con
$d_\text{sae} = 2028 * 8$.

Deseamos encontrar hiperparámetros para entrenar SAEs, para esto:
- Usamos una primera aproximación razonable modificando los hiperparámetros para
  los autoencoders más pequeños entrenados en salidas de las MLPs de gemma 2
  2 B
- Ajustamos una power law en base a 2 entrenamientos de SAEs más pequeñas,
  usando el mismo learning rate.
- Ajustamos una power law para el learning rate con los hiperparámetos optimos
  que estimó el paso anterior.

Si ignoramos la relación posicional de los tokens y asumimos una distribución
uniforme, tenemos que la entropía por token es

$$
  \log_2 (\text{tamaño del vocabulario})
$$
ya que el vocabulario de gemma2 2B es 256000 y el de llama es 128000, obtenemos
que el número de tokens equivalente sería 4.2 mil millones. Ya que nuestro
modelo es la mitad de tamaño de gemma2 2B, una primera cantidad de datos
razonable para entrenar nuesto sae más grande es $2.1B$

En el caso de llama3.2 1B, eso resultaría en

$$
  N_l D = (2.1 \times 10^9)(1.2 \times 10^9) \approx 2.5 \times 10^{18}
$$
Una RTX 4090 puede realisar cada segundo un máximo de $165 \times 10^{12}$
operaciones con tensores de 16 bits y acumulador de 32 bits (referencia
al reporte técnico v1.0.1), luego, estimamos 4.2 horas de entrenamiento
tan solo considerando la computación en el modelo llama.

Ahora, para estimar las horas-RTX4090 para el autoencoder, en el caso de
entrenarlo en la salida de la MLP intermedia, tendríamos

$$
    6(2.1 \times 10^9)(2048^2)(8)(2) = 8.5 \times 10^{17}
$$

# El modelo

In [1]:
import torch
from torch import nn
import jaxtyping
import dataclasses
from tqdm import tqdm
from torch.nn import functional as F
from torch.linalg import vector_norm
import math
import datasets
import numpy as np

In [2]:
class Step(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, threshold):
        ctx.save_for_backward(x, threshold)
        return (x > threshold).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        bandwidth = 0.001
        x, threshold = ctx.saved_tensors

        grad_threshold = torch.where(
            abs(F.relu(x) - threshold) < bandwidth/2,
            -1.0/bandwidth, 0)
        
        return torch.zeros_like(x), grad_threshold * grad_output

In [3]:
dtype = torch.float32

In [4]:
class Sae(nn.Module):
    def __init__(self, d_in, d_sae, use_pre_enc_bias=True, **kwargs):
        super().__init__(**kwargs)
        self.enc = nn.Linear(d_in, d_sae, dtype=dtype)
        self.dec = nn.Linear(d_sae, d_in, dtype=dtype)
        with torch.no_grad():
            # normalize each of the d_sae dictonary vectors
            self.dec.weight /= vector_norm(self.dec.weight, dim=0, keepdim=True)
            self.enc.weight.copy_(self.dec.weight.clone().t())
            self.enc.bias.copy_(torch.zeros_like(self.enc.bias))
            self.dec.bias.copy_(torch.zeros_like(self.dec.bias))
        self.log_threshold = nn.Parameter(
            torch.log(torch.full((d_sae,), 0.001, dtype=dtype)))
        self.use_pre_enc_bias = use_pre_enc_bias
        def project_out_parallel_grad(dim, tensor):
            @torch.no_grad
            def hook(grad_in):
                # norm along dim=dim of the tensor is assumed to be 1 as we
                # are going to normalize it after every grad update
                dot = (tensor * grad_in).sum(dim=dim, keepdim=True)
                return grad_in - dot * tensor
            return hook

        self.dec.weight.register_hook(
            project_out_parallel_grad(0, self.dec.weight))
                

    def forward(self,
        x,
        return_mask=False,
        return_l0=True,
        return_reconstruction_loss=True,
    ):
        "We compute this much here so that compile() can do its magic"
        # as per train_gpt2.py on karpathy's llm.c repo, there are performance
        # reasons not to return stuff
        d = {}
        original_input = x
        if self.use_pre_enc_bias:
            x = x - self.dec.bias
        
        x = self.enc(x)
        threshold = torch.exp(self.log_threshold)
        s = Step.apply(x, threshold)
        if return_mask:
            d['mask'] = s
        if return_l0:
            d['l0'] = s.float().mean(0).sum(-1)
        if not return_reconstruction_loss:
            return d
        x = x*s
        x = self.dec(x)
        # print(((((x - original_input).float())**2) == 0).sum().to(int))

        # d['reconstruction'] = ((x - original_input)**2).mean(0, dtype=torch.float32).sum()
        d['reconstruction'] = ((x.float() - original_input.float()).pow(2)).mean(0).sum()
        # use pow(2)?
        # d['reconstruction'] = ((x - original_input)**2).sum(1, dtype=torch.float32).mean()
        # d['reconstruction'] = (torch.linalg.vector_norm(x - original_input, dim=1, dtype=torch.float32)**2).mean()

        return d

In [5]:
# d_in = 1024
# d_sae = d_in*8
# sae = Sae(d_in, d_sae)
# sae.to("cuda")
# sae.compile()

In [6]:
# b = 512
# x = torch.randn(b, d_in)
# x = x.to(torch.float16).to("cuda")
# # x /= x.norm(dim=1, keepdim=True).mean(dim=1, keepdim=True)
# x /= torch.linalg.vector_norm(x, dim=1, keepdim=True).mean(dim=1, keepdim=True)
# x = x.to(torch.float16)
# d = sae(x)
# d['reconstruction'].detach().cpu().numpy()


In [7]:
# collection = []
# for i in range(20):
#     if not (i % 10):
#         print(i)
#     x = torch.randn(b, d_in)
#     # x /= x.norm(dim=1, keepdim=True).mean(dim=1, keepdim=True)
#     x /= torch.linalg.vector_norm(x, dim=1, keepdim=True).mean(dim=1, keepdim=True)
#     x = x.half()
#     sae = Sae(d_in, d_sae)
#     if not sae.enc.weight.detach().isfinite().all():
#         print("enc blew up")
#     if not sae.dec.weight.detach().isfinite().all():
#         print("dec blew up")
#     d = sae(x)
#     l0, reconstruction = d['l0'].detach(), d['reconstruction'].detach()
#     if not l0.isfinite().all():
#         print("l0 blew up")
#     if not reconstruction.isfinite().all():
#         print("reconstruction blew up")
#     collection.append((l0, reconstruction))

In [8]:
def cosine_schedule_with_warmup(
    current_step: int,
    warmup_steps: int,
    total_steps: int
    ):
    if current_step < warmup_steps:
        lr =  (1 + current_step) / warmup_steps
        # print(lr)
        return lr
    progress = (current_step - warmup_steps) / (total_steps - warmup_steps)
    lr =  0.5 * (1 + math.cos(math.pi * progress))
    # print(lr)
    return lr

In [9]:
def sparsity_schedule(step, warmup_steps, max_sparsity_coeff):
    if step >= warmup_steps:
        return max_sparsity_coeff
    return step+1 / warmup_steps

In [10]:
def fake_train_loader(batch, d_in, total_steps):
    for _ in range(total_steps):
        x = torch.randn(batch, d_in)
        x = torch.randn(batch, d_in)
        x /= vector_norm(x, dim=1, keepdim=True)
        yield x

In [11]:
ds = datasets.load_dataset("naraca/mi-dataset-activaciones-llama3_2")

In [12]:
ds

DatasetDict({
    train: Dataset({
        features: ['activacion'],
        num_rows: 1000
    })
})

In [13]:
train_ds = ds['train']
train_ds.set_format(type='numpy')

In [14]:
to_stack = []
new_batch_size = 8192//2
# we are guaranteed these batches have <= 8192 in ax 0
current_batch_size = 0
for i, batch in enumerate(tqdm(train_ds['activacion'])):
    batch = batch.astype(np.float16)
    if current_batch_size + batch.shape[0] < new_batch_size:
        to_stack.append(batch)
        current_batch_size += batch.shape[0]
        continue
    to_stack.append(batch[:new_batch_size - current_batch_size])
    new_batch = np.concat(to_stack, axis=0)
    np.save(f"/workspace/data/slice{i}.npy", new_batch)
    to_stack.clear()
    to_stack.append(batch[new_batch_size - current_batch_size:])
    current_batch_size = batch.shape[0] + current_batch_size - new_batch_size

    # print(current_batch_size)
    # if i > 50:
    #     break

100%|██████████| 1000/1000 [00:06<00:00, 165.67it/s]


In [15]:
import pathlib
from torch.utils.data import Dataset, DataLoader

class NPYFolder(Dataset):
    def __init__(self, root):
        self.root = pathlib.Path(root)
        self.npys = sorted(self.root.glob('*.npy'))
    def __len__(self):
        return len(self.npys)
    def __getitem__(self, idx):
        return torch.from_numpy(np.load(self.npys[idx]))



In [16]:

train_ds = NPYFolder('/workspace/data/')
train_dl = DataLoader(train_ds)

In [17]:
len(train_ds)

163

In [18]:
# mean_norm = 
for x in train_ds:
    norm = torch.linalg.norm(x, dim=1).mean()
    print(norm.item())

3.408203125
3.2890625
3.390625
3.388671875
3.453125
3.44921875
3.4453125
3.3828125
3.375
3.455078125
3.390625
3.458984375
3.40625
3.42578125
3.435546875
3.44921875
3.47265625
3.404296875
3.40234375
3.4296875
3.42578125
3.416015625
3.400390625
3.384765625
3.41796875
3.3671875
3.451171875
3.427734375
3.435546875
3.41796875
3.4609375
3.43359375
3.388671875
3.447265625
3.396484375
3.32421875
3.3515625
3.41796875
3.427734375
3.328125
3.447265625
3.46484375
3.4453125
3.4453125
3.375
3.453125
3.46484375
3.423828125
3.40234375
3.421875
3.361328125
3.375
3.439453125
3.4375
3.419921875
3.455078125
3.392578125
3.453125
3.455078125
3.4140625
3.453125
3.48828125
3.412109375
3.466796875
3.435546875
3.380859375
3.3984375
3.4140625
3.4453125
3.44921875
3.396484375
3.421875
3.443359375
3.392578125
3.42578125
3.453125
3.421875
3.404296875
3.443359375
3.515625
3.400390625
3.416015625
3.392578125
3.40234375
3.361328125
3.447265625
3.45703125
3.44921875
3.39453125
3.380859375
3.341796875
3.431640625
3.4433

In [None]:
steps = 2**16
max_lr = 7e-5
d_in = 2048
d_sae = 2048*8
model = Sae(d_in, d_sae)
model.to('cuda')
model.compile()
warmup_steps=2000
sparcity_warmup_steps=500
total_steps=4000 #for now
batch = 1024
optimizer = torch.optim.Adam(model.parameters(), lr=max_lr, betas=(0,0.999))
max_sparsity_coeff = 6000

scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: cosine_schedule_with_warmup(step, warmup_steps, total_steps)
)
# train_ds = fake_train_loader(batch, d_in, total_steps)
total_step = 0
for epoch in range(100):
    for step, x in enumerate(tqdm(train_ds)):
        x /= 3.4
        x = x.to(dtype).to("cuda")
        optimizer.zero_grad()
        d = model(x)
        reconstruction_loss, l0 = d['reconstruction'], d['l0']
        sparsity_coefficient = sparsity_schedule(total_step, sparcity_warmup_steps, max_sparsity_coeff)
        loss = reconstruction_loss + sparsity_coefficient * l0
        # log losses, compute stats, etc
        grad = loss.backward()
        # norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # metrics
        if (total_step % 500) == 0:
            with torch.no_grad():
                # print metrics
                print(f"reconstruction={reconstruction_loss.item()}")
                print(f"l0={l0.item()}")
                # print(f"norm={norm.item()}")
                print(f"{sparsity_coefficient=}")
        optimizer.step()
        # TODO: sparsity_coefficient scheduler
        # print(scheduler.get_lr())
        scheduler.step()

        # normalize
        with torch.no_grad():
            wdecnorm = vector_norm(model.dec.weight, dim=0, keepdim=True)
            model.dec.weight /= wdecnorm
    # print(f"epoch loss: {loss.detach().item()}")
        total_step +=1



  2%|▏         | 4/163 [00:00<00:10, 15.57it/s]

reconstruction=13.014205932617188
l0=7876.583984375
sparsity_coefficient=0.002


100%|██████████| 163/163 [00:10<00:00, 14.98it/s]
100%|██████████| 163/163 [00:10<00:00, 14.86it/s]
100%|██████████| 163/163 [00:10<00:00, 15.55it/s]
 10%|█         | 17/163 [00:00<00:06, 21.81it/s]

reconstruction=1.788827657699585
l0=3777.816162109375
sparsity_coefficient=6000


100%|██████████| 163/163 [00:07<00:00, 20.60it/s]
100%|██████████| 163/163 [00:07<00:00, 21.32it/s]
 82%|████████▏ | 134/163 [00:06<00:01, 22.31it/s]


KeyboardInterrupt: 

In [None]:
model.log_threshold

Parameter containing:
tensor([nan, nan, nan,  ..., nan, nan, nan], device='cuda:0',
       dtype=torch.float16, requires_grad=True)

In [None]:
model.dec.weight

Parameter containing:
tensor([[-0.0239, -0.0280, -0.0218,  ...,  0.0366, -0.0213, -0.0363],
        [ 0.0276, -0.0075,  0.0171,  ..., -0.0142,  0.0292,  0.0326],
        [ 0.0206,  0.0087, -0.0327,  ..., -0.0137, -0.0060, -0.0237],
        ...,
        [ 0.0017,  0.0329,  0.0052,  ...,  0.0046,  0.0026,  0.0004],
        [-0.0197, -0.0227, -0.0090,  ...,  0.0057, -0.0132, -0.0104],
        [ 0.0386, -0.0066,  0.0127,  ..., -0.0363,  0.0285,  0.0007]],
       device='cuda:0', requires_grad=True)

In [None]:
# new_batch_size = 8192
# old_batches_to_stack = []
# left_over = np.ndarray([[]], dtype=np.float16)
# for i, old_batch in enumerate(tqdm(train_ds['activacion'])):
#     # at this point we should have 0 or 1 items in the to_stack list
#     # if we have then we take data from it then from old batches
#     # construct as much new batches as possible:
#     old_batches_sliced = np.array_split(old_batch, new_batch_size)
#     for old_batch_slice in old_batches_sliced[:-1]:
#         new_batch = np.concatenate([
#             left_over,
#             old_batch_slice[left_over.shape[0]:]
#         ])
#         # npy save
#         left_over = old_batch_slice[:left_over.shape[0]]
#     # check if it's possible to contruct a new batch
#     if left_over.shape[0] + old_batches_sliced[-1].shape[0] >= new_batch_size:
#         new_batch = np.concatenate([
#             left_over,
#             old_batches_sliced[-1][left_over.shape[0]:]
#             ])
#         left_over = old_batches_sliced[-1][:left_over.shape[0]]

SyntaxError: invalid syntax. Perhaps you forgot a comma? (1002483432.py, line 13)

In [None]:
# new_batch_size = 8192
# old_batches_to_stack = []
# leftover = None
# # invariant at all lines: leftover.shape[0] < new_batch_size
# i = 0
# total_tokens = 0
# for old_batch in tqdm(train_ds['activacion']):
#     total_tokens += old_batch.shape[0]
#     # construct as much new batches as possible out of the old batch:
#     old_batches_sliced = np.array_split(old_batch,
#                                         old_batch.shape[0] // new_batch_size
#                                         + 1 if old_batch.shape[0] % new_batch_size else 0)
#     # print(len(old_batches_sliced))
#     # break
#     for old_batch_slice in old_batches_sliced:
#         if old_batch_slice + 0 if leftover is None else leftover.shape[0] < new_batch_size:
#             break
#         # all of old_batch_slice is enough to make at least 1 new_batch:
#         # collect all its activations
#         new_batch = old_batch_slice if leftover is None else np.concatenate([
#             leftover,
#             old_batch_slice[:leftover.shape[0]]
#         ])
#         assert new_batch.shape[0] == new_batch_size
#         # np.save(f"/workspace/data/slice{i}.npy", new_batch)
#         print(f"saved {i}")
#         i += 1
#         leftover = old_batch_slice if leftover is None else old_batch_slice[leftover.shape[0]:]
#         assert leftover is None or leftover.shape[0] < new_batch_size
#     assert old_batches_sliced[-1].shape[0] <= new_batch_size
#     # check if it's possible to contruct a new batch
#     # notice it would be at most 1 
#     if (0 if leftover is None else leftover.shape[0]) + old_batches_sliced[-1].shape[0] >= new_batch_size:
#         # print('a', (0 if leftover is None else leftover.shape[0]) + old_batches_sliced[-1].shape[0])
#         new_batch = old_batches_sliced[-1] if leftover is None else np.concatenate([
#             leftover,
#             old_batches_sliced[-1][:new_batch_size - leftover.shape[0]]
#             ])
#         assert new_batch.shape[0] == new_batch_size
#         leftover = old_batches_sliced[-1] if leftover is None else old_batches_sliced[-1][new_batch_size - leftover.shape[0]:]
#         assert leftover is None or leftover.shape[0] < new_batch_size
#         # np.save(f"/workspace/data/slice{i}.npy", new_batch)
#         print(f"saved {i}")
#         i += 1
#     else:
#         leftover = old_batches_sliced[-1]

#     # if leftover is not None:
#     #     print(leftover.shape[0])
#     # else:
#     #     print(0)
#     if i >= 20:
#         print('exiting')
#         break

# if leftover is not None:
#     # np.save(f"/workspace/data/slice{i}.npy", leftover)
#     i += 1
#     print(f"saved {i}")


# print(total_tokens)

100%|██████████| 1000/1000 [00:00<00:00, 8006.02it/s]

saved 0
saved 1
saved 2
saved 3
saved 4
saved 5
saved 6
saved 7
saved 8
saved 9
saved 11
668501





In [None]:
len(old_batches_sliced)

8192

In [None]:
a = train_ds.select(range(2))
a.set_format(type='torch')

In [None]:
t = a['activacion'][1]
t = t.half().numpy()
np.save("/workspace/data/t.npy", t)

In [None]:
t2 = torch.from_numpy(np.load("/workspace/data/t.npy"))
t2

tensor([[-0.0185, -0.1119,  0.0157,  ..., -0.1203, -0.0603,  0.1016],
        [-0.0088, -0.0302, -0.1082,  ..., -0.0330,  0.0157, -0.0439],
        [ 0.0186,  0.0679, -0.0203,  ...,  0.0202,  0.1085,  0.1808],
        ...,
        [ 0.1361, -0.0654,  0.0571,  ...,  0.0568, -0.0781,  0.0022],
        [-0.0706, -0.1146,  0.0350,  ...,  0.0627,  0.0798,  0.0478],
        [ 0.0402, -0.0783,  0.0140,  ..., -0.0611,  0.1315,  0.0892]],
       dtype=torch.float16)

In [None]:
t2