In [179]:
# Upload to huggingface
import os

import plotly.express as px
import torch
from huggingface_hub import HfApi, HfFolder

from gated_sae import GatedSAE
from vanilla_sae import SparseAutoencoder


# Take the Hugging Face access token from the HF_TOKEN environment variable rather
# than embedding it in the notebook: a token saved in a notebook is leaked to anyone
# who can read the file or its history and must be revoked.
# If HF_TOKEN is unset, hf_token is None and HfApi falls back to the locally saved
# login (e.g. from `huggingface-cli login`).
hf_token = os.environ.get("HF_TOKEN")
api = HfApi(token=hf_token)

In [180]:
# Load the saved autoencoder with all of its storages mapped onto the CPU.
# NOTE(review): the variable is named `gated_sae`, but the checkpoint file is
# 'data/sparse_sae.pt' -- confirm which model is intended here.
gated_sae = torch.load('data/sparse_sae.pt', map_location='cpu')

In [181]:
# Display the loaded module's layer structure (the cell's last expression is rendered
# as its repr). Despite the variable name, the recorded output is a SparseAutoencoder.
gated_sae

SparseAutoencoder(
  (encoder): Sequential(
    (0): TiedBias()
    (1): ConstrainedUnitNormLinear()
    (2): ReLU()
  )
  (decoder): Sequential(
    (0): ConstrainedUnitNormLinear()
    (1): TiedBias()
  )
)

In [182]:
# Load the cached tensors onto the CPU, matching how the autoencoder is loaded in this
# notebook; the plotting cells convert these tensors with `.numpy()`, which only works
# for CPU tensors.
sae_errors = torch.load('data/sae_errors.pt', map_location='cpu')
original_z = torch.load('data/original_z.pt', map_location='cpu')

# The two tensors are subtracted element-wise later on, so a shape mismatch would
# otherwise broadcast silently (or fail far from the cause).
assert sae_errors.shape == original_z.shape, (
    f"shape mismatch: sae_errors {tuple(sae_errors.shape)} vs original_z {tuple(original_z.shape)}"
)
print(sae_errors.shape, original_z.shape)

torch.Size([143872, 768]) torch.Size([143872, 768])


In [183]:
# Sanity-check a single row: compare the first feature of the original activation,
# the SAE error, the implied SAE prediction, and the loaded model's output.
idx = 0
x = original_z[idx, :].unsqueeze(0)  # keep a leading batch dimension of 1
y = sae_errors[idx, :].unsqueeze(0)

print(f"Original Z act: {x[0][0]:.3f}")
print(f"SAE error: {y[0][0]:.3f}")

# SAE prediction = original activation minus the SAE's error (computed for every row;
# reused by the plotting cell).
sae_pred = original_z - sae_errors
print(f"SAE pred: {sae_pred[idx, 0]:.3f}")

# Inference only: skip autograd bookkeeping for the forward pass.
# The model is called with the activation row and the error row and returns a pair,
# unpacked here as (recon, loss); `loss` is not used in this cell.
with torch.no_grad():
    recon, loss = gated_sae(x, y)
print(f"Recon: {recon[0][0]:.3f}")

Original Z act: -0.080
SAE error: -0.011
SAE pred: -0.068
Recon: -0.050


In [184]:
# Feature-axis length taken from the data instead of a hard-coded 768.
n_features = original_z.shape[1]

# Plot 1: original activations for row `idx`, overlaid with the SAE prediction.
fig = px.line(x=range(n_features), y=original_z[idx, :].detach().numpy(), title='Original z')
fig.add_scatter(y=sae_pred[idx, :].detach().numpy(), mode='lines', name='SAE Pred')
fig.show()

# Plot 2: SAE error for row `idx`, overlaid with the loaded model's output for that row.
fig = px.line(x=range(n_features), y=sae_errors[idx, :].detach().numpy(), title='SAE Error')
fig.add_scatter(y=recon.squeeze().detach().numpy(), mode='lines', name='Reconstructed Error')
fig.show()

In [178]:
def reconstruct_row(row: int):
    """Run the loaded autoencoder on one row of the cached tensors.

    Args:
        row: Row index into `original_z` / `sae_errors`.

    Returns:
        The first element of the pair returned by the model call (the second
        element, unpacked as the loss, is discarded).
    """
    # Inference only, so no autograd graph is needed.
    with torch.no_grad():
        recon, _loss = gated_sae(
            original_z[row, :].unsqueeze(0),
            sae_errors[row, :].unsqueeze(0),
        )
    return recon


# Using a helper instead of copy-pasted code keeps `idx`, `x` and `y` untouched, so
# re-running other cells afterwards does not silently switch to a different row.
recon_0 = reconstruct_row(0)
recon_1 = reconstruct_row(1)

# Overlay the model outputs for rows 0 and 1 to see how much they differ.
recon_0_values = recon_0.squeeze().detach().numpy()
recon_1_values = recon_1.squeeze().detach().numpy()
fig = px.line(x=range(len(recon_0_values)), y=recon_0_values, title='Reconstructed 0')
fig.add_scatter(y=recon_1_values, mode='lines', name='Reconstructed 1')
fig.show()

## Effect of L1 regularisation

In [1]:
import train
from tqdm import tqdm

def run_experiments(model_type: str, l1_coefficients: list, n_epochs: int = 1):
    """Sweep over L1 coefficients, training once per value.

    Args:
        model_type: Passed through unchanged to ``train.main`` as ``model_type``.
        l1_coefficients: Coefficients to try, processed in the order given.
        n_epochs: Passed through to ``train.main`` for every run.

    Returns:
        Dict mapping each coefficient to the value ``train.main`` returned for it
        (reported as the final reconstruction error).
    """
    recon_loss_by_coeff = {}
    sweep = tqdm(l1_coefficients, desc="Running Experiments")
    for coeff in sweep:
        print(f"Running experiment with l1_coefficient={coeff}")
        recon_loss = train.main(
            model_type=model_type,
            n_epochs=n_epochs,
            l1_coefficient=coeff,
        )
        print(f"Final Reconstruction Error for l1_coefficient={coeff}: {recon_loss:.4f}")
        recon_loss_by_coeff[coeff] = recon_loss
    return recon_loss_by_coeff

In [2]:
# L1 coefficients to sweep: no penalty, then values spanning roughly ten orders of magnitude.
l1_coefficients = [0, 1e-9, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 2, 5, 10]
results = run_experiments(
    model_type='sparse',
    l1_coefficients=l1_coefficients,
    n_epochs=1,
)
# Header and the coefficient -> reconstruction-error mapping, one per line.
print("Experiment Results:", results, sep="\n")

Running Experiments:   0%|          | 0/11 [00:00<?, ?it/s]

Running experiment with l1_coefficient=0


  _torch_pytree._register_pytree_node(
Evaluating...: 100%|██████████| 450/450 [00:01<00:00, 339.87it/s]


Initial Test Loss 0.0605 | Initial Reconstruction Error 0.0605




Epoch 1



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

Test Loss 0.0066 | Reconstruction Error 0.0066
Final Reconstruction Error for l1_coefficient=0: 0.0066
Running experiment with l1_coefficient=1e-09


Evaluating...: 100%|██████████| 450/450 [00:01<00:00, 280.10it/s]


Initial Test Loss 0.0614 | Initial Reconstruction Error 0.0614




Epoch 1



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A