# Training an SAE on the outputs and activations of the middle MLP of llama3.2 1B

## Training cost

Generally, the computational cost of an LLM is dominated by the evaluation
of its MLPs (reference to the university seminar given by a former Anthropic
employee).

A simple approximation of the training cost of a model is

$$
  6ND 
$$
measured in FLOPs (floating-point operations) (cite Chinchilla scaling laws),
where $N$ is the number of parameters and $D$ is the number of samples (tokens)
in the training set.

This is because, in general, each parameter takes part in one floating-point
multiplication and one floating-point addition per sample, which gives a cost
of $2ND$ for the forward pass alone. The backward pass typically costs twice as
much as the forward pass, that is, $4ND$. Adding the two gives the result
stated above.
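
The breakdown above can be written as a small helper. This is only a sketch of the $6ND$ rule of thumb; the function name `training_flops` and the example sizes are illustrative assumptions, not values from this notebook.

```python
def training_flops(n_params: int, n_samples: int) -> dict:
    """Split the 6*N*D estimate into its forward and backward parts."""
    # One multiply and one add per parameter per sample.
    forward = 2 * n_params * n_samples
    # The backward pass is assumed to cost twice the forward pass.
    backward = 2 * forward
    return {"forward": forward, "backward": backward, "total": forward + backward}


# Illustrative sizes only: 1e9 parameters trained on 1e9 samples.
flops = training_flops(10**9, 10**9)
print(flops["total"])  # 6000000000000000000
```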

Let $N_l$ be the number of parameters of llama.
Since we train the autoencoder on an MLP halfway through llama, we only need to
evaluate the first half of the model. We also skip the backward pass over
llama's parameters, because we do not want to modify them: they stay frozen.
Only the forward pass remains, at 2 FLOPs per parameter per sample over
$N_l/2$ parameters, so the FLOPs spent in that half of llama are

$$
  N_l D
$$

As for the SAE, considering only the cost of applying its matrices, we have

$$
  6 (2d_\text{in}d_\text{sae})D
$$
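
Both cost terms can be sketched as functions. The names `frozen_prefix_flops` and `sae_training_flops` are hypothetical helpers introduced here for illustration; biases and the threshold parameters are ignored, as in the formulas above.

```python
def frozen_prefix_flops(n_model_params, n_samples, fraction=0.5):
    """Forward-only cost of running the first `fraction` of a frozen model."""
    # 2 FLOPs per parameter per sample, and no backward pass.
    return 2 * fraction * n_model_params * n_samples


def sae_training_flops(d_in, d_sae, n_samples):
    """6*N*D with N = 2*d_in*d_sae (encoder and decoder matrices only)."""
    n_sae_params = 2 * d_in * d_sae
    return 6 * n_sae_params * n_samples


# Cost of one training sample for an SAE with d_in=2048 and d_sae=2048*8.
print(sae_training_flops(2048, 2048 * 8, 1))  # 402653184
```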

In Gemma Scope, the smallest SAEs trained on MLP outputs were trained on
4 billion activation vectors. The vectors in the residual stream (and
therefore the outputs of the MLP layers) have dimension 2048, and
$d_\text{sae} = 2048 \times 8$.

We want to find hyperparameters for training SAEs. To do so:
- We start from a reasonable first approximation by adapting the hyperparameters
  of the smallest autoencoders trained on MLP outputs of gemma 2 2B.
- We fit a power law based on 2 training runs of smaller SAEs, using the same
  learning rate.
- We fit a power law for the learning rate using the optimal hyperparameters
  estimated in the previous step.
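
One possible way to do the power-law fits is a least-squares line fit in log-log space. The sketch below assumes that approach; `fit_power_law` is a hypothetical helper and the data points are synthetic, generated from a known law only to check the fit.

```python
import math


def fit_power_law(xs, ys):
    """Fit y = a * x**b by ordinary least squares on (log x, log y)."""
    lx = [math.log(x) for x in xs]
    ly = [math.log(y) for y in ys]
    n = len(lx)
    mx, my = sum(lx) / n, sum(ly) / n
    # The slope of the log-log line is the exponent b.
    b = sum((u - mx) * (v - my) for u, v in zip(lx, ly)) / sum(
        (u - mx) ** 2 for u in lx
    )
    # The intercept of the log-log line is log(a).
    a = math.exp(my - b * mx)
    return a, b


# Synthetic points from y = 3 * x**-0.5 (not measurements from any run).
xs = [1e6, 4e6]
ys = [3 * x ** -0.5 for x in xs]
a, b = fit_power_law(xs, ys)
print(a, b)
```

With only two runs the fitted curve passes exactly through both points, so the fit carries no information about its own error.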

If we ignore the positional relationship between tokens and assume a uniform
distribution, the entropy per token is

$$
  \log_2 (\text{tamaño del vocabulario})
$$
Since the vocabulary of gemma2 2B has 256000 tokens and llama's has 128000,
each gemma token carries $\log_2(256000)/\log_2(128000) \approx 1.06$ times as
many bits as a llama token, so 4 billion gemma tokens are equivalent to about
4.2 billion llama tokens. Since our model is half the size of gemma2 2B, a
reasonable first amount of data for training our largest SAE is
$2.1 \times 10^9$ tokens.
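
The token conversion can be reproduced with a few lines. This is a sketch under the same uniform-distribution assumption; `equivalent_tokens` is a hypothetical helper name.

```python
import math


def equivalent_tokens(n_tokens, vocab_from, vocab_to):
    """Tokens of the target vocabulary that carry the same number of bits."""
    bits = n_tokens * math.log2(vocab_from)
    return bits / math.log2(vocab_to)


gemma_tokens = 4e9
llama_equiv = equivalent_tokens(gemma_tokens, 256_000, 128_000)
print(f"{llama_equiv:.2e}")      # about 4.2e9 llama tokens
print(f"{llama_equiv / 2:.2e}")  # about 2.1e9 after halving for model size
```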

For llama3.2 1B, this gives

$$
  N_l D = (2.1 \times 10^9)(1.2 \times 10^9) \approx 2.5 \times 10^{18}
$$
An RTX 4090 can perform at most $165 \times 10^{12}$ operations per second
with 16-bit tensors and a 32-bit accumulator (reference to technical report
v1.0.1), so we estimate 4.2 hours of training for the computation in the
llama model alone.

Now, to estimate the RTX 4090 hours for the autoencoder, when training it on
the output of the middle MLP, we would have

$$
    6(2.1 \times 10^9)(2048^2)(8)(2) \approx 8.5 \times 10^{17}
$$
FLOPs, or about 1.4 hours at the same peak throughput.
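
The arithmetic behind both time estimates can be checked as follows. The sketch assumes the GPU runs at the quoted peak throughput the whole time, so the results are best read as lower bounds; `gpu_hours` is a hypothetical helper.

```python
PEAK_FLOPS = 165e12  # RTX 4090 peak quoted above (FP16 inputs, FP32 accumulate)


def gpu_hours(flops, peak=PEAK_FLOPS):
    """Hours needed to execute `flops` operations at a sustained `peak` rate."""
    return flops / peak / 3600


n_llama = 1.2e9    # parameters of llama3.2 1B, as used in the text
n_samples = 2.1e9  # activation vectors to train on
d_in, d_sae = 2048, 2048 * 8

llama_flops = n_llama * n_samples               # forward pass over half the model
sae_flops = 6 * (2 * d_in * d_sae) * n_samples  # forward + backward through the SAE

print(f"llama: {gpu_hours(llama_flops):.1f} h")  # about 4.2 h
print(f"sae:   {gpu_hours(sae_flops):.1f} h")    # about 1.4 h
```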

# The model

In [45]:
import torch
from torch import nn
import jaxtyping
import dataclasses
from tqdm import tqdm
from torch.nn import functional as F
from torch.linalg import vector_norm
import math
import datasets
import numpy as np
from torch.utils.data import Dataset, DataLoader

In [None]:
dtype = torch.float32
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
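class Step(torch.autograd.Function):
    """Heaviside step with a pseudo-gradient for the threshold."""

    @staticmethod
    def forward(ctx, x, threshold):
        ctx.save_for_backward(x, threshold)
        # 1 where the pre-activation exceeds its threshold, 0 elsewhere
        return (x > threshold).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        bandwidth = 0.001
        x, threshold = ctx.saved_tensors

        # The true derivative is zero almost everywhere, so a rectangular
        # window of width `bandwidth` around the threshold stands in for the
        # derivative with respect to the threshold
        grad_threshold = torch.where(
            abs(F.relu(x) - threshold) < bandwidth/2,
            -1.0/bandwidth, 0)

        # no gradient flows to x through the step itself
        return torch.zeros_like(x), grad_threshold * grad_output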
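
To see what the backward pass of `Step` computes, here is a scalar, pure-Python version of the same rule. It is only an illustration: `threshold_pseudo_grad` is a hypothetical name, and it omits the multiplication by `grad_output`.

```python
def threshold_pseudo_grad(x: float, threshold: float, bandwidth: float = 0.001) -> float:
    """Scalar version of the threshold gradient used in Step.backward."""
    # Non-zero only when relu(x) lies within bandwidth/2 of the threshold.
    inside_window = abs(max(x, 0.0) - threshold) < bandwidth / 2
    return -1.0 / bandwidth if inside_window else 0.0


print(threshold_pseudo_grad(0.0012, 0.001))  # inside the window: about -1000
print(threshold_pseudo_grad(0.5, 0.001))     # far above the threshold: 0.0
print(threshold_pseudo_grad(-0.3, 0.001))    # negative pre-activation: 0.0
```

Only pre-activations that land very close to the threshold move it, which is why the bandwidth matters: too small and almost no sample contributes a gradient.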

In [None]:
class Sae(nn.Module):
    def __init__(self, d_in, d_sae, use_pre_enc_bias=True, **kwargs):
        super().__init__(**kwargs)
        self.enc = nn.Linear(d_in, d_sae, dtype=dtype)
        self.dec = nn.Linear(d_sae, d_in, dtype=dtype)
        with torch.no_grad():
            # normalize each of the d_sae dictionary vectors
            self.dec.weight /= vector_norm(self.dec.weight, dim=0, keepdim=True)
            self.enc.weight.copy_(self.dec.weight.clone().t())
            self.enc.bias.copy_(torch.zeros_like(self.enc.bias))
            self.dec.bias.copy_(torch.zeros_like(self.dec.bias))
        self.log_threshold = nn.Parameter(
            torch.log(torch.full((d_sae,), 0.001, dtype=dtype)))
        self.use_pre_enc_bias = use_pre_enc_bias
        def project_out_parallel_grad(dim, tensor):
            @torch.no_grad()
            def hook(grad_in):
                # norm along dim=dim of the tensor is assumed to be 1 as we
                # are going to normalize it after every grad update
                dot = (tensor * grad_in).sum(dim=dim, keepdim=True)
                return grad_in - dot * tensor
            return hook

        self.dec.weight.register_hook(
            project_out_parallel_grad(0, self.dec.weight))
                

    def forward(self,
        x,
        return_mask=False,
        return_l0=True,
        return_reconstruction_loss=True,
    ):
        "We compute this much here so that compile() can do its magic"
        # as per train_gpt2.py on karpathy's llm.c repo, there are performance
        # reasons not to return stuff
        d = {}
        original_input = x
        if self.use_pre_enc_bias:
            x = x - self.dec.bias
        
        x = self.enc(x)
        threshold = torch.exp(self.log_threshold)
        s = Step.apply(x, threshold)
        if return_mask:
            d['mask'] = s
        if return_l0:
            d['l0'] = s.float().mean(0).sum(-1)
        if not return_reconstruction_loss:
            return d
        x = x*s
        x = self.dec(x)

        d['reconstruction'] = ((x - original_input).pow(2)).mean(0).sum()

        return d

In [None]:
def cosine_schedule_with_warmup(
    current_step: int,
    warmup_steps: int,
    total_steps: int
    ):
    """Learning-rate multiplier: linear warmup, then cosine decay to zero."""
    if current_step < warmup_steps:
        lr = (1 + current_step) / warmup_steps
        return lr
    progress = (current_step - warmup_steps) / (total_steps - warmup_steps)
    lr = 0.5 * (1 + math.cos(math.pi * progress))
    return lr
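
A quick way to sanity-check the schedule is to evaluate the multiplier at a few steps. The sketch repeats the function so it runs on its own and uses the `max_lr`, `warmup_steps` and `total_steps` values from the training cell further down; the chosen steps are arbitrary.

```python
import math


def cosine_schedule_with_warmup(current_step, warmup_steps, total_steps):
    # Same logic as the cell above, repeated so the snippet is self-contained.
    if current_step < warmup_steps:
        return (1 + current_step) / warmup_steps
    progress = (current_step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))


max_lr, warmup_steps, total_steps = 7e-5, 2000, 32000
factors = {
    step: cosine_schedule_with_warmup(step, warmup_steps, total_steps)
    for step in (0, 1999, 17000, 32000)
}
for step, factor in factors.items():
    # LambdaLR multiplies the base learning rate by this factor.
    print(step, factor, factor * max_lr)
```

The multiplier starts at `1 / warmup_steps`, reaches 1 at the end of warmup, passes through 0.5 halfway through the decay and ends at 0.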

In [10]:
def sparsity_schedule(step, warmup_steps, max_sparsity_coeff):
    if step >= warmup_steps:
        return max_sparsity_coeff
    return max_sparsity_coeff*((step+1) / warmup_steps)
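
The sparsity coefficient ramps linearly and then stays flat. The sketch below repeats the function and evaluates it with the warmup length (10000) and maximum coefficient (0.0005) used in the training cell; the values at steps 0 and 500 match the first two entries of the training log.

```python
def sparsity_schedule(step, warmup_steps, max_sparsity_coeff):
    # Same logic as the cell above, repeated so the snippet is self-contained.
    if step >= warmup_steps:
        return max_sparsity_coeff
    return max_sparsity_coeff * ((step + 1) / warmup_steps)


coeffs = {step: sparsity_schedule(step, 10000, 0.0005) for step in (0, 500, 10000, 20000)}
for step, coeff in coeffs.items():
    print(step, coeff)
# step 0 gives about 5e-08, step 500 about 2.505e-05, and 0.0005 from step 10000 on
```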

In [44]:
from datasets import load_dataset

# Load the dataset
ds = load_dataset('mech-interp-uam/llama-mlp8-outputs')
ds.set_format('numpy')

# Check the dataset structure
print(f"Dataset loaded successfully: {len(ds['train'])} examples")
print(f"Features: {ds['train'].features}")
print(f"First example shape: {ds['train'][0]['activations'].shape}")

Dataset loaded successfully: 668501 examples
Features: {'activations': Sequence(feature=Value(dtype='float16', id=None), length=2048, id=None)}
First example shape: (2048,)


In [None]:
import torch

class ActivationsDataset(Dataset):
    def __init__(self, hf_dataset):
        self.data = hf_dataset

    def __getitem__(self, idx):
        # Return as float16; on modern GPUs the float16 -> float32 conversion
        # is reportedly cheap compared to matmuls
        return torch.tensor(self.data[idx]['activations'])

    def __len__(self):
        return len(self.data)


In [50]:
batch = 1024
dataset = ActivationsDataset(ds['train'])
dataloader = DataLoader(dataset, batch_size=batch, shuffle=True, num_workers=4)

In [51]:
len(dataloader)

653
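
The loader length and the amount of data a run actually sees follow from the dataset size. This sketch uses the numbers printed above and the `total_steps = 32000` setting of the training cell; the sample count is a slight overestimate because the last batch of each epoch is smaller than 1024.

```python
import math

n_examples = 668_501  # size of the train split printed above
batch_size = 1024
total_steps = 32_000  # value used in the training cell

steps_per_epoch = math.ceil(n_examples / batch_size)
full_epochs, extra_steps = divmod(total_steps, steps_per_epoch)
samples_seen = total_steps * batch_size

print(steps_per_epoch)           # 653, matching len(dataloader)
print(full_epochs, extra_steps)  # 49 full passes over the data plus 3 steps
print(samples_seen)              # 32768000
print(samples_seen / 2.1e9)      # fraction of the 2.1e9 target from the cost section
```

The run therefore revisits the same 668501 activations about 49 times and covers under 2% of the data budget estimated earlier.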

In [52]:
steps = 2**16
max_lr = 7e-5
d_in = 2048
d_sae = 2048*8
model = Sae(d_in, d_sae)
model.to('cuda')
#model.compile()
warmup_steps=2000
sparcity_warmup_steps=10000
total_steps=32000 #for now
optimizer = torch.optim.Adam([
    {"params": model.enc.parameters(), "lr":max_lr, "betas":(0.0,0.999)},
    {"params": model.dec.parameters(), "lr":max_lr, "betas":(0.0,0.999)},
    {"params": model.log_threshold, "lr":3.5*max_lr, "betas":(0.9,0.999)},
])
max_sparsity_coeff = 0.0005

scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: cosine_schedule_with_warmup(step, warmup_steps, total_steps)
)
# train_ds = fake_train_loader(batch, d_in, total_steps)
total_step = 0
should_break = False
for epoch in range(1000):
    for step, x in enumerate(tqdm(dataloader)):
        x /= 3.4 # this is supposed to be the expected norm
        x = x.to(dtype).to("cuda")
        optimizer.zero_grad()
        d = model(x)
        reconstruction_loss, l0 = d['reconstruction'], d['l0']
        sparsity_coefficient = sparsity_schedule(total_step, sparcity_warmup_steps, max_sparsity_coeff)
        loss = reconstruction_loss + sparsity_coefficient * l0
        # log losses, compute stats, etc
        loss.backward()
        # norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # metrics
        if (total_step % 500) == 0:
            with torch.no_grad():
                # print metrics
                print(f"reconstruction={reconstruction_loss.item()}")
                print(f"l0={l0.item()}")
                # print(f"norm={norm.item()}")
                print(f"{sparsity_coefficient=}")
        optimizer.step()
        # TODO: sparsity_coefficient scheduler
        # print(scheduler.get_lr())
        scheduler.step()

        # normalize
        with torch.no_grad():
            wdecnorm = vector_norm(model.dec.weight, dim=0, keepdim=True)
            model.dec.weight /= wdecnorm
    # print(f"epoch loss: {loss.detach().item()}")
        total_step +=1
        if total_step > total_steps:
            should_break = True
            break
    if should_break:
        break



  0%|          | 1/653 [00:00<04:30,  2.41it/s]

reconstruction=13.11782169342041
l0=7894.46875
sparsity_coefficient=5.0000000000000004e-08


 77%|███████▋  | 505/653 [00:25<00:07, 20.26it/s]

reconstruction=0.43382197618484497
l0=1923.9560546875
sparsity_coefficient=2.505e-05


100%|██████████| 653/653 [00:33<00:00, 19.71it/s]
 53%|█████▎    | 349/653 [00:17<00:15, 20.08it/s]

reconstruction=0.3316606879234314
l0=1460.61328125
sparsity_coefficient=5.005e-05


100%|██████████| 653/653 [00:33<00:00, 19.71it/s]
 30%|███       | 197/653 [00:10<00:22, 19.85it/s]

reconstruction=0.2898387908935547
l0=1476.85546875
sparsity_coefficient=7.505000000000001e-05


100%|██████████| 653/653 [00:33<00:00, 19.62it/s]
  7%|▋         | 45/653 [00:02<00:30, 19.64it/s]

reconstruction=0.2538369297981262
l0=1487.107421875
sparsity_coefficient=0.00010005


 84%|████████▎ | 546/653 [00:27<00:05, 19.96it/s]

reconstruction=0.22999128699302673
l0=1495.8603515625
sparsity_coefficient=0.00012505


100%|██████████| 653/653 [00:33<00:00, 19.49it/s]
 60%|█████▉    | 391/653 [00:20<00:14, 18.44it/s]

reconstruction=0.2109168916940689
l0=1480.689453125
sparsity_coefficient=0.00015005


100%|██████████| 653/653 [00:34<00:00, 19.00it/s]
 36%|███▋      | 237/653 [00:12<00:20, 20.21it/s]

reconstruction=0.19227583706378937
l0=1490.029296875
sparsity_coefficient=0.00017505000000000003


100%|██████████| 653/653 [00:34<00:00, 19.04it/s]
 13%|█▎        | 85/653 [00:04<00:31, 18.18it/s]

reconstruction=0.1758834272623062
l0=1499.6142578125
sparsity_coefficient=0.00020005


 89%|████████▉ | 584/653 [00:32<00:04, 14.76it/s]

reconstruction=0.16457249224185944
l0=1477.5703125
sparsity_coefficient=0.00022505


100%|██████████| 653/653 [00:36<00:00, 17.83it/s]
 66%|██████▌   | 430/653 [00:24<00:13, 16.25it/s]

reconstruction=0.15042296051979065
l0=1470.66015625
sparsity_coefficient=0.00025005


100%|██████████| 653/653 [00:36<00:00, 17.86it/s]
 43%|████▎     | 281/653 [00:15<00:21, 17.02it/s]

reconstruction=0.1397533267736435
l0=1445.5048828125
sparsity_coefficient=0.00027505000000000004


100%|██████████| 653/653 [00:35<00:00, 18.53it/s]
 19%|█▉        | 125/653 [00:06<00:27, 19.24it/s]

reconstruction=0.1297432780265808
l0=1428.697265625
sparsity_coefficient=0.00030005


 96%|█████████▌| 626/653 [00:34<00:01, 20.21it/s]

reconstruction=0.12482182681560516
l0=1377.8701171875
sparsity_coefficient=0.00032505


100%|██████████| 653/653 [00:36<00:00, 17.97it/s]
 73%|███████▎  | 474/653 [00:25<00:08, 19.95it/s]

reconstruction=0.12232045829296112
l0=1307.60546875
sparsity_coefficient=0.00035004999999999997


100%|██████████| 653/653 [00:34<00:00, 18.85it/s]
 49%|████▉     | 322/653 [00:16<00:15, 21.14it/s]

reconstruction=0.1226264238357544
l0=1221.4765625
sparsity_coefficient=0.00037505


100%|██████████| 653/653 [00:32<00:00, 19.80it/s]
 26%|██▌       | 169/653 [00:08<00:23, 20.20it/s]

reconstruction=0.12564939260482788
l0=1112.89453125
sparsity_coefficient=0.00040005000000000005


100%|██████████| 653/653 [00:33<00:00, 19.59it/s]
  2%|▏         | 13/653 [00:01<00:40, 15.82it/s]

reconstruction=0.13368447124958038
l0=989.8837890625
sparsity_coefficient=0.00042505


 79%|███████▊  | 513/653 [00:26<00:06, 20.18it/s]

reconstruction=0.15467005968093872
l0=873.8203125
sparsity_coefficient=0.00045005


100%|██████████| 653/653 [00:33<00:00, 19.73it/s]
 55%|█████▌    | 362/653 [00:18<00:14, 20.32it/s]

reconstruction=0.16833460330963135
l0=744.994140625
sparsity_coefficient=0.00047505


100%|██████████| 653/653 [00:32<00:00, 19.79it/s]
 32%|███▏      | 210/653 [00:10<00:21, 20.38it/s]

reconstruction=0.18350555002689362
l0=657.658203125
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.65it/s]
  9%|▊         | 57/653 [00:03<00:29, 19.89it/s]

reconstruction=0.19543102383613586
l0=583.380859375
sparsity_coefficient=0.0005


 85%|████████▍ | 554/653 [00:28<00:05, 18.69it/s]

reconstruction=0.21201559901237488
l0=517.6044921875
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.74it/s]
 61%|██████▏   | 401/653 [00:20<00:12, 20.05it/s]

reconstruction=0.21768631041049957
l0=472.861328125
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.70it/s]
 38%|███▊      | 250/653 [00:12<00:20, 19.47it/s]

reconstruction=0.22179633378982544
l0=434.7255859375
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.65it/s]
 15%|█▍        | 97/653 [00:05<00:27, 20.39it/s]

reconstruction=0.22328394651412964
l0=401.330078125
sparsity_coefficient=0.0005


 91%|█████████▏| 596/653 [00:30<00:02, 19.79it/s]

reconstruction=0.23193836212158203
l0=368.330078125
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.61it/s]
 68%|██████▊   | 445/653 [00:22<00:10, 19.97it/s]

reconstruction=0.23430505394935608
l0=345.4345703125
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.66it/s]
 44%|████▍     | 289/653 [00:14<00:18, 19.99it/s]

reconstruction=0.23476551473140717
l0=327.69921875
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.66it/s]
 21%|██        | 137/653 [00:07<00:26, 19.72it/s]

reconstruction=0.2344397008419037
l0=313.3310546875
sparsity_coefficient=0.0005


 98%|█████████▊| 637/653 [00:32<00:00, 19.89it/s]

reconstruction=0.24456503987312317
l0=300.0029296875
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.65it/s]
 74%|███████▍  | 485/653 [00:24<00:08, 20.03it/s]

reconstruction=0.24218963086605072
l0=288.2880859375
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.61it/s]
 51%|█████     | 333/653 [00:17<00:15, 20.33it/s]

reconstruction=0.24388307332992554
l0=278.9794921875
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.49it/s]
 27%|██▋       | 177/653 [00:09<00:23, 20.07it/s]

reconstruction=0.24185879528522491
l0=268.4248046875
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.64it/s]
  4%|▍         | 25/653 [00:01<00:33, 18.94it/s]

reconstruction=0.23956242203712463
l0=264.9541015625
sparsity_coefficient=0.0005


 81%|████████  | 527/653 [00:26<00:06, 20.61it/s]

reconstruction=0.24742871522903442
l0=255.7783203125
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.74it/s]
 57%|█████▋    | 373/653 [00:19<00:13, 20.47it/s]

reconstruction=0.24588991701602936
l0=248.982421875
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.68it/s]
 34%|███▍      | 221/653 [00:11<00:21, 20.29it/s]

reconstruction=0.24496322870254517
l0=246.4169921875
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.69it/s]
 10%|▉         | 65/653 [00:03<00:29, 19.97it/s]

reconstruction=0.2444174736738205
l0=241.1533203125
sparsity_coefficient=0.0005


 87%|████████▋ | 566/653 [00:28<00:04, 20.02it/s]

reconstruction=0.24732226133346558
l0=235.826171875
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.45it/s]
 63%|██████▎   | 413/653 [00:21<00:12, 19.78it/s]

reconstruction=0.24833106994628906
l0=235.2275390625
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.59it/s]
 40%|███▉      | 261/653 [00:13<00:19, 20.60it/s]

reconstruction=0.2459370642900467
l0=232.3447265625
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.21it/s]
 17%|█▋        | 109/653 [00:06<00:30, 17.96it/s]

reconstruction=0.2458069622516632
l0=228.451171875
sparsity_coefficient=0.0005


 93%|█████████▎| 609/653 [00:33<00:02, 19.41it/s]

reconstruction=0.24929657578468323
l0=224.7158203125
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:35<00:00, 18.18it/s]
 69%|██████▉   | 453/653 [00:25<00:10, 18.58it/s]

reconstruction=0.2464187741279602
l0=223.8583984375
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:35<00:00, 18.16it/s]
 46%|████▌     | 302/653 [00:16<00:18, 18.97it/s]

reconstruction=0.2489900141954422
l0=224.828125
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:36<00:00, 17.89it/s]
 23%|██▎       | 149/653 [00:07<00:26, 19.36it/s]

reconstruction=0.24792447686195374
l0=222.6650390625
sparsity_coefficient=0.0005


100%|█████████▉| 651/653 [00:33<00:00, 21.37it/s]

reconstruction=0.24904105067253113
l0=216.8779296875
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.34it/s]
 76%|███████▌  | 497/653 [00:25<00:07, 20.20it/s]

reconstruction=0.24724061787128448
l0=217.2939453125
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.60it/s]
 52%|█████▏    | 341/653 [00:17<00:15, 19.64it/s]

reconstruction=0.2482089400291443
l0=215.5009765625
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.75it/s]
 29%|██▉       | 189/653 [00:09<00:23, 19.77it/s]

reconstruction=0.24859726428985596
l0=217.4033203125
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.62it/s]
  6%|▌         | 37/653 [00:02<00:31, 19.33it/s]

reconstruction=0.24643957614898682
l0=213.1767578125
sparsity_coefficient=0.0005


 82%|████████▏ | 537/653 [00:27<00:05, 20.45it/s]

reconstruction=0.24896806478500366
l0=213.90234375
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.73it/s]
 59%|█████▉    | 385/653 [00:19<00:13, 20.22it/s]

reconstruction=0.24896764755249023
l0=214.794921875
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.63it/s]
 36%|███▌      | 232/653 [00:11<00:19, 21.71it/s]

reconstruction=0.24893565475940704
l0=211.212890625
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.61it/s]
 12%|█▏        | 77/653 [00:04<00:28, 20.02it/s]

reconstruction=0.24782709777355194
l0=211.2392578125
sparsity_coefficient=0.0005


 88%|████████▊ | 577/653 [00:29<00:03, 20.16it/s]

reconstruction=0.24873071908950806
l0=211.0771484375
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.72it/s]
 65%|██████▌   | 425/653 [00:21<00:11, 19.87it/s]

reconstruction=0.24832455813884735
l0=212.0986328125
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.38it/s]
 42%|████▏     | 273/653 [00:14<00:19, 19.95it/s]

reconstruction=0.24692755937576294
l0=210.6708984375
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.69it/s]
 18%|█▊        | 118/653 [00:06<00:27, 19.49it/s]

reconstruction=0.24846002459526062
l0=209.1455078125
sparsity_coefficient=0.0005


 95%|█████████▍| 619/653 [00:31<00:01, 20.53it/s]

reconstruction=0.24751469492912292
l0=209.36328125
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.64it/s]
 72%|███████▏  | 467/653 [00:23<00:09, 20.08it/s]

reconstruction=0.24772629141807556
l0=211.4912109375
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:33<00:00, 19.76it/s]
 48%|████▊     | 313/653 [00:15<00:16, 20.50it/s]

reconstruction=0.24499791860580444
l0=207.42578125
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:32<00:00, 20.02it/s]
 25%|██▍       | 161/653 [00:09<00:25, 19.54it/s]

reconstruction=0.24641597270965576
l0=205.611328125
sparsity_coefficient=0.0005


100%|██████████| 653/653 [00:34<00:00, 18.97it/s]
  0%|          | 3/653 [00:00<02:25,  4.47it/s]

reconstruction=0.24850615859031677
l0=209.23046875
sparsity_coefficient=0.0005





In [55]:
torch.save(model.state_dict(), "/workspace/llama3.2-1b-sae/sae.pth")