In [1]:
# Upload to huggingface
from huggingface_hub import HfApi, HfFolder
import torch
from gated_sae import GatedSAE
from vanilla_sae import SparseAutoencoder
import plotly.express as px

In [2]:
# gated_sae = torch.load('data/gated_sae.pt', map_location=torch.device('cpu'))

import torch
from huggingface_hub import hf_hub_download
from gated_sae import GatedSAE

# Define the function to download and load the model
def load_gated_sae(repo_id, filename, n_input_features, projection_up, l1_coefficient):
    # Download the model file from HuggingFace Hub
    file_path = hf_hub_download(repo_id=repo_id, filename=filename)
    
    # Initialize the model architecture
    # model = GatedSAE(
    #     n_input_features=n_input_features, 
    #     n_learned_features=n_input_features * projection_up, 
    #     l1_coefficient=l1_coefficient
    # )
    
    # Load the state dict into the model
    # model.load_state_dict(torch.load(file_path, map_location=torch.device('cpu')))
    model = torch.load(file_path, map_location=torch.device('cpu'))
    
    return model

# Define parameters
repo_id = 'charlieoneill/error-saes'
filename = 'sae_layer_9.pt'
n_input_features = 768
projection_up = 8
l1_coefficient = 1e-4

# Load the model
gated_sae = load_gated_sae(repo_id, filename, n_input_features, projection_up, l1_coefficient)

# Print model summary (optional)
print(gated_sae)

GatedSAE(
  (activation_fn): ReLU()
)


In [3]:
gated_sae

GatedSAE(
  (activation_fn): ReLU()
)

In [4]:
sae_errors = torch.load('data/sae_errors.pt')
original_z = torch.load('data/original_z.pt')
print(sae_errors.shape, original_z.shape)

torch.Size([143872, 768]) torch.Size([143872, 768])


In [5]:
idx = 1
x = original_z[idx, :].unsqueeze(0)
y = sae_errors[idx, :].unsqueeze(0)

print(f"Original Z act: {x[0][0]:.3f}")
print(f"SAE error: {y[0][0]:.3f}")

sae_pred = original_z - sae_errors
print(f"SAE pred: {sae_pred[idx, 0]:.3f}")

recon, loss, recon_loss = gated_sae(x, y)
print(f"Recon: {recon[0][0]:.3f}") 
print(f"Loss: {loss:.3f}")
print(f"Recon loss: {recon_loss:.3f}")

Original Z act: -0.124
SAE error: -0.280
SAE pred: 0.156
Recon: 0.035
Loss: 0.038
Recon loss: 0.017


  _torch_pytree._register_pytree_node(


In [6]:
# Original z line plot
fig = px.line(x=range(0, 768), y=original_z[idx, :].detach().numpy(), title='Original z')
fig.add_scatter(y=sae_pred[idx, :].detach().numpy(), mode='lines', name='SAE Pred')
fig.show()

# Original z line plot
fig = px.line(x=range(0, 768), y=sae_errors[idx, :].detach().numpy(), title='SAE Error')
fig.add_scatter(y=recon.squeeze().detach().numpy(), mode='lines', name='Reconstructed Error')
fig.show()

In [7]:
idx = 0
x = original_z[idx, :].unsqueeze(0)
y = sae_errors[idx, :].unsqueeze(0)
sae_pred = original_z - sae_errors
recon_0, loss, _ = gated_sae(x, y)


idx = 1
x = original_z[idx, :].unsqueeze(0)
y = sae_errors[idx, :].unsqueeze(0)
sae_pred = original_z - sae_errors
recon_1, loss, _ = gated_sae(x, y)
recon_1 += 0.0

# Plot recon_0 and recon_1
fig = px.line(x=range(0, 768), y=recon_0.squeeze().detach().numpy(), title='Reconstructed 0')
fig.add_scatter(y=recon_1.squeeze().detach().numpy(), mode='lines', name='Reconstructed 1')
fig.show()

## Effect of L1 regularisation

In [1]:
import train
from tqdm import tqdm
import plotly.express as px

def run_experiments(model_type: str, l1_coefficients: list, n_epochs: int = 1):
    results = {}
    for l1_coefficient in tqdm(l1_coefficients, desc="Running Experiments"):
        print(f"Running experiment with l1_coefficient={l1_coefficient}")
        final_recon_loss, l0_loss = train.main(model_type=model_type, n_epochs=n_epochs, l1_coefficient=l1_coefficient)
        results[l1_coefficient] = (final_recon_loss, l0_loss)
        print(f"Final Reconstruction Error for l1_coefficient={l1_coefficient}: {final_recon_loss:.4f}")
    return results

In [2]:
l1_coefficients = [5e-5, 8e-5, 1e-4, 3e-4]
results = run_experiments(model_type='gated', l1_coefficients=l1_coefficients, n_epochs=2)
print("Experiment Results:")
print(results)

Running Experiments:   0%|          | 0/4 [00:00<?, ?it/s]

Running experiment with l1_coefficient=5e-05
Creating new dataloaders...


  _torch_pytree._register_pytree_node(
Evaluating...: 100%|██████████| 450/450 [00:02<00:00, 208.06it/s]


Initial Test Loss 0.3398 | Initial Reconstruction Error 0.1622 | Initial L0 Loss 3041.0969




Epoch 1



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

Test Loss 0.0257 | Reconstruction Error 0.0119 | L0 Loss 190.3671
Epoch 2



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

Test Loss 0.0242 | Reconstruction Error 0.0112 | L0 Loss 192.7533
Final Reconstruction Error for l1_coefficient=5e-05: 0.0112
Running experiment with l1_coefficient=8e-05
Loading existing dataloaders...


Evaluating...: 100%|██████████| 450/450 [00:02<00:00, 174.61it/s]


Initial Test Loss 0.3619 | Initial Reconstruction Error 0.1683 | Initial L0 Loss 3108.6596




Epoch 1



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

Test Loss 0.0275 | Reconstruction Error 0.0127 | L0 Loss 151.5951
Epoch 2



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

Test Loss 0.0260 | Reconstruction Error 0.0120 | L0 Loss 154.6644
Final Reconstruction Error for l1_coefficient=8e-05: 0.0120
Running experiment with l1_coefficient=0.0001
Loading existing dataloaders...


Evaluating...: 100%|██████████| 450/450 [00:02<00:00, 176.61it/s]


Initial Test Loss 0.3624 | Initial Reconstruction Error 0.1657 | Initial L0 Loss 3080.4782




Epoch 1



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

Test Loss 0.0286 | Reconstruction Error 0.0131 | L0 Loss 124.8363
Epoch 2



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

Test Loss 0.0270 | Reconstruction Error 0.0123 | L0 Loss 135.4560
Final Reconstruction Error for l1_coefficient=0.0001: 0.0123
Running experiment with l1_coefficient=0.0003
Loading existing dataloaders...


Evaluating...: 100%|██████████| 450/450 [00:02<00:00, 172.26it/s]


Initial Test Loss 0.4241 | Initial Reconstruction Error 0.1652 | Initial L0 Loss 3091.2309




Epoch 1



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

Test Loss 0.0345 | Reconstruction Error 0.0160 | L0 Loss 41.7433
Epoch 2



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

Test Loss 0.0327 | Reconstruction Error 0.0151 | L0 Loss 41.9012
Final Reconstruction Error for l1_coefficient=0.0003: 0.0151
Experiment Results:
{5e-05: (0.011243170698483785, 192.75325676812065), 8e-05: (0.011950277843409114, 154.66444888644747), 0.0001: (0.012324780451340807, 135.45596065945097), 0.0003: (0.015099402125924826, 41.90117432488336)}





In [3]:
results

{5e-05: (0.011243170698483785, 192.75325676812065),
 8e-05: (0.011950277843409114, 154.66444888644747),
 0.0001: (0.012324780451340807, 135.45596065945097),
 0.0003: (0.015099402125924826, 41.90117432488336)}

In [4]:
recon_errors = [result[0] for result in results.values()]
l0_errors = [result[1] for result in results.values()]

In [5]:
# Plotly plot results
fig = px.line(x=l1_coefficients, y=recon_errors, 
              labels={'x': 'L1 Coefficient', 'y': 'Reconstruction Error'},
              title='Reconstruction Error vs L1 Coefficient (Gated SAE)', width=600)
# Log x-axis
fig.update_xaxes(type="log")
fig.show()

In [6]:
# L0 loss
fig = px.line(x=l1_coefficients, y=l0_errors, 
              labels={'x': 'L1 Coefficient', 'y': 'L0 Loss'},
              title='L0 Loss vs L1 Coefficient (Gated SAE)', width=600)
# Log x-axis
fig.update_xaxes(type="log")
fig.show()

## Evaluating the SAE with and without SAE error

In [22]:
# Load in the test dataset
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from train import GatedSAEDataset
import sys
sys.path.append('../src')
from circuit_lens import get_model_encoders

# Load in the test dataset
test_dataset = torch.load('data/test_dataset.pt')
test_dataset[0][0].shape, test_dataset[0][1].shape

(torch.Size([768]), torch.Size([768]))

In [23]:
# Get GPT2 (model), and our SAE
model, z_saes, _ = get_model_encoders(device='cpu')
layer = 9 
sae = z_saes[layer]
del z_saes


torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.



Loaded pretrained model gpt2-small into HookedTransformer


100%|██████████| 12/12 [00:09<00:00,  1.33it/s]
100%|██████████| 12/12 [00:05<00:00,  2.29it/s]


In [46]:
z_original = test_dataset[0][0].unsqueeze(0)
loss, z_reconstruct, acts, l2_loss, l1_loss = sae(z_original)
error = z_original.float() - z_reconstruct.float()
l2_loss_ours = (z_reconstruct.float() - z_original.float()).pow(2).sum(-1)

# Print all shapes
print(f"Original Z shape: {z_original.shape}")
print(f"Reconstructed Z shape: {z_reconstruct.shape}")
print(f"Error shape: {error.shape}")
print(f"Layer {layer} activations shape: {acts.shape}")
print(f"L2 Loss: {l2_loss:.4f}")
print(f"L2 Loss (Ours): {l2_loss_ours.item():.4f}")
print(f"L1 Loss: {l1_loss:.4f}")
print(f"Mean error = {error.abs().mean(-1).item():.4f}")

Original Z shape: torch.Size([1, 768])
Reconstructed Z shape: torch.Size([1, 768])
Error shape: torch.Size([1, 768])
Layer 9 activations shape: torch.Size([1, 24576])
L2 Loss: 27.6021
L2 Loss (Ours): 27.6021
L1 Loss: 4.7549
Mean error = 0.1198


In [48]:
error.abs().mean(-1)

tensor([0.1198], grad_fn=<MeanBackward1>)

In [29]:
test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=False)

In [53]:
# Initial metrics: l2 loss, l1 loss, mean error
evaluation_results = []
for i, (original_z, _) in enumerate(tqdm(test_dataloader)):
    original_z = original_z.to('cpu')
    loss, z_reconstruct, acts, l2_loss, l1_loss = sae(original_z)
    error = original_z.float() - z_reconstruct.float()
    mean_error = error.abs().sum(-1).mean()
    evaluation_results.append((l2_loss.item(), mean_error.item()))

# Divide by number of batches
l2_loss, mean_error = zip(*evaluation_results)
l2_loss = sum(l2_loss) / len(test_dataloader)
mean_error = sum(mean_error) / len(test_dataloader)

# Print
print(f"L2 Loss: {l2_loss:.4f}")
print(f"Mean Error: {mean_error:.4f}")

100%|██████████| 113/113 [00:02<00:00, 46.33it/s]

L2 Loss: 20.6656
Mean Error: 79.7133





In [55]:
# Now, we test the same metrics but with our trained gated SAE
gated_sae = torch.load('data/gated_sae.pt', map_location=torch.device('cpu'))
gated_sae = gated_sae.to('cpu')
gated_sae.eval()

# Initial metrics: l2 loss, l1 loss, mean error
evaluation_results = []
for i, (original_z, _) in enumerate(tqdm(test_dataloader)):
    original_z = original_z.to('cpu')
    _, z_reconstruct, acts, _, _ = sae(original_z)
    error = original_z - z_reconstruct
    predicted_error, _, _ = gated_sae(original_z, error)
    # Add the predicted error to the z_reconstruct
    z_reconstruct = z_reconstruct + predicted_error
    new_error = original_z - z_reconstruct
    # Evaluate the same metrics as before
    l2_loss = (z_reconstruct - original_z).pow(2).sum(-1).mean().item()
    mean_error = new_error.abs().sum(-1).mean().item()
    evaluation_results.append((l2_loss, mean_error))

# Divide by number of batches
l2_loss, mean_error = zip(*evaluation_results)
l2_loss = sum(l2_loss) / len(test_dataloader)
mean_error = sum(mean_error) / len(test_dataloader)

# Print
print(f"L2 Loss: {l2_loss:.4f}")
print(f"Mean Error: {mean_error:.4f}")
    

100%|██████████| 113/113 [00:03<00:00, 32.33it/s]

L2 Loss: 9.2236
Mean Error: 56.9743





In [61]:
import einops
from torch.utils.data import DataLoader, Dataset

class TokenizedDataset(Dataset):
    def __init__(self, tokenized_dataset):
        self.tokenized_dataset = tokenized_dataset

    def __len__(self):
        return len(self.tokenized_dataset)

    def __getitem__(self, idx):
        return self.tokenized_dataset[idx]

tokenized_dataset = torch.load('data/tokenized_dataset.pt')
dataset = TokenizedDataset(tokenized_dataset)
print(f"Length of tokenised dataset = {len(tokenized_dataset)}")

dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# Disable torch grad
torch.set_grad_enabled(False)

# Get the first batch
batch = next(iter(dataloader))
_, cache = model.run_with_cache(batch)

Length of tokenised dataset = 1124


In [63]:
cache["z", 9].shape # batch, seq, n_heads, head_dim

torch.Size([16, 128, 12, 64])

In [77]:
# Now we want some more comprehensive metrics
# Basically, if we patch in the original SAE reconstruction, and then original SAE + predicted error
# We will compare this to zero ablation and random ablation as a baseline
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import get_act_name
from functools import partial
import torch
import torch.nn.functional as F

def calculate_kl_divergence(clean_logits, patched_logits):
    # Ensure the inputs are log probabilities
    clean_log_probs = F.log_softmax(clean_logits, dim=-1)
    patched_log_probs = F.log_softmax(patched_logits, dim=-1)
    
    # Convert patched_logits to probabilities
    patched_probs = torch.exp(patched_log_probs)
    
    # Calculate KL divergence for each element in the batch and sequence
    kl_div = F.kl_div(clean_log_probs, patched_probs, reduction='none')
    
    # Average over the vocabulary size (last dimension)
    kl_div = kl_div.sum(dim=-1)
    
    # Average over the batch and sequence length
    kl_div = kl_div.mean(dim=0).mean(dim=0)
    
    return kl_div.item()

def attention_head_z_patching_hook(attention_head_z, hook: HookPoint, layer: int, sae: SparseAutoencoder, gated_sae: GatedSAE):
    z_acts = einops.rearrange(attention_head_z, "b s h d -> (b s) (h d)")
    if sae is not None:
        # Get the reconstructions from the SAE
        _, z_reconstruct, _, _, _ = sae(z_acts)
    else:
        z_reconstruct = torch.zeros_like(z_acts)
    if gated_sae is not None:
        # Get the error
        error = z_acts - z_reconstruct
        # Get the predicted error
        predicted_error, _, _ = gated_sae(z_acts, error)
        # Add the predicted error to the z_reconstruct
        z_reconstruct = z_reconstruct + predicted_error
    # Rearrange back into original shape
    z_reconstruct = einops.rearrange(z_reconstruct, "(b s) (h d) -> b s h d", b=attention_head_z.shape[0], s=attention_head_z.shape[1], h=attention_head_z.shape[2], d=attention_head_z.shape[3])
    attention_head_z = z_reconstruct
    return attention_head_z

# Let's try running this on the first batch
clean_logits, clean_loss = model(batch, return_type="both")

hook_fn = partial(attention_head_z_patching_hook, layer=9, sae=None, gated_sae=None)
patched_logits, patched_loss = model.run_with_hooks(
    batch,
    fwd_hooks=[(get_act_name("z", layer, "attn"), hook_fn)],
    return_type="both"
)

In [78]:
clean_loss, patched_loss

(tensor(3.8419), tensor(3.9419))

In [79]:
kl_divergence = calculate_kl_divergence(clean_logits, patched_logits)
print(f"KL Divergence: {kl_divergence}")

KL Divergence: 0.0716462954878807


In [81]:
# Now we need to write a function to do it for all batches
def run_ablation_experiment(dataloader, model, sae, gated_sae):
    kl_divergences, loss_differences = [], []
    for i, batch in enumerate(tqdm(dataloader)):
        clean_logits, clean_loss = model(batch, return_type="both")
        hook_fn = partial(attention_head_z_patching_hook, layer=9, sae=sae, gated_sae=gated_sae)
        patched_logits, patched_loss = model.run_with_hooks(
            batch,
            fwd_hooks=[(get_act_name("z", 9, "attn"), hook_fn)],
            return_type="both"
        )
        kl_divergence = calculate_kl_divergence(clean_logits, patched_logits)
        kl_divergences.append(kl_divergence)
        loss_difference = patched_loss - clean_loss
        loss_differences.append(loss_difference)

    # Average and return
    kl_divergence = sum(kl_divergences) / len(dataloader)
    loss_difference = sum(loss_differences) / len(dataloader)

    # Normalise the loss difference 
    # loss_difference = loss_difference / clean_loss
    return kl_divergence, loss_difference

kl_divergence, loss_difference = run_ablation_experiment(dataloader, model, sae, gated_sae)
print(f"KL Divergence: {kl_divergence}")
print(f"Loss Difference: {loss_difference}")

100%|██████████| 71/71 [02:12<00:00,  1.87s/it]

KL Divergence: 0.007976958023148104
Loss Difference: -0.0004855985171161592





In [82]:
# Without using the gated SAE
kl_divergence, loss_difference = run_ablation_experiment(dataloader, model, sae, gated_sae=None)
print(f"KL Divergence: {kl_divergence}")
print(f"Loss Difference: {loss_difference}")

100%|██████████| 71/71 [02:12<00:00,  1.86s/it]

KL Divergence: 0.01950965033398128
Loss Difference: 0.0069394903257489204





In [83]:
# Zero ablation
kl_divergence, loss_difference = run_ablation_experiment(dataloader, model, sae=None, gated_sae=None)
print(f"KL Divergence: {kl_divergence}")
print(f"Loss Difference: {loss_difference}")

100%|██████████| 71/71 [02:00<00:00,  1.69s/it]

KL Divergence: 0.06767696436022369
Loss Difference: 0.026526078581809998





In [87]:
from huggingface_hub import HfApi, HfFolder
# Import login
from huggingface_hub import notebook_login

notebook_login()

# Ensure you are logged in to HuggingFace
# You can use the following command to log in via the terminal
# huggingface-cli login

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [88]:
api = HfApi()
username = "charlieoneill"
repo_name = "error-saes"
repo_id = f"{username}/{repo_name}"

repo_url = api.create_repo(repo_id=repo_id, private=False, token="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")
print(f"Repository {repo_id} created at: {repo_url}")

In [None]:
# upload gated saes to the repo
hf_folder = HfFolder()
hf_folder.push_to_hub(repo_id=repo_id, token="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")

In [None]:
repo_name = "charlieoneill/error-saes"

file_name = "sae_layer_9.pt"

gated_sae 

In [1]:
import torch
from huggingface_hub import hf_hub_download
from gated_sae import GatedSAE

# Define the function to download and load the model
def load_gated_sae(repo_id, filename, n_input_features, projection_up, l1_coefficient):
    # Download the model file from HuggingFace Hub
    file_path = hf_hub_download(repo_id=repo_id, filename=filename)
    
    # Initialize the model architecture
    # model = GatedSAE(
    #     n_input_features=n_input_features, 
    #     n_learned_features=n_input_features * projection_up, 
    #     l1_coefficient=l1_coefficient
    # )
    
    # Load the state dict into the model
    # model.load_state_dict(torch.load(file_path, map_location=torch.device('cpu')))
    model = torch.load(file_path, map_location=torch.device('cpu'))
    
    return model

# Define parameters
repo_id = 'charlieoneill/error-saes'
filename = 'sae_layer_9.pt'
n_input_features = 768
projection_up = 8
l1_coefficient = 1e-4

# Load the model
model = load_gated_sae(repo_id, filename, n_input_features, projection_up, l1_coefficient)

# Print model summary (optional)
print(model)

GatedSAE(
  (activation_fn): ReLU()
)


In [2]:
model

GatedSAE(
  (activation_fn): ReLU()
)