In [1]:
#!/usr/bin/env python3
import torch
from jumprelu_sae import JumpReLUSAE  # Make sure jumprelu_sae.py is in your PYTHONPATH or same directory
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from huggingface_hub import HfApi


def main(
    repo_id: str = "charlieoneill/gemma-medicine-sae",
    filename: str = "10000_128.pt",
    d_sae: int = 16384,
    sparsity_coeff: float = 100.0,
    batch_size: int = 1024,
    epochs: int = 10,
    log_freq: int = 10,
    lr: float = 1e-3,
    seed=None,
):
    """Download cached activations from the Hugging Face Hub and train a JumpReLU SAE on them.

    Every argument defaults to the value this notebook originally hard-coded, so a bare
    ``main()`` trains exactly as before; pass overrides for quick experiments
    (e.g. ``main(epochs=1)``).

    Args:
        repo_id: Hugging Face Hub repository that holds the activation file.
        filename: Name of the ``torch.save``-d activation tensor inside ``repo_id``.
        d_sae: Number of SAE latents.
        sparsity_coeff: Sparsity coefficient forwarded to ``JumpReLUSAE``.
        batch_size: Mini-batch size forwarded to ``optimize_on_activations``
            (per the original notes, each token is treated as a separate example).
        epochs: Number of training epochs forwarded to ``optimize_on_activations``.
        log_freq: How often training statistics are logged.
        lr: Base learning rate.
        seed: If not None, passed to ``torch.manual_seed`` right before the SAE is built.

    Returns:
        None. Loading failures are printed and stop the run instead of raising.
    """
    # Automatically select between CUDA and CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load the cached activations. The training call below expects shape
    # (N, seq_len, d_model); the filename suggests N=10000, seq_len=128 (unverified).
    try:
        activation_file = hf_hub_download(repo_id=repo_id, filename=filename)
        # The file comes from the network, so refuse arbitrary pickled objects
        # (weights_only=True; this also silences torch's weights_only FutureWarning).
        # Without CUDA, remap storages to CPU so a tensor saved from a GPU still loads;
        # with CUDA, keep torch's default placement.
        activations = torch.load(
            activation_file,
            map_location=None if device.type == "cuda" else "cpu",
            weights_only=True,
        )
    except Exception as e:
        # Best-effort: report and stop rather than raising out of the notebook cell.
        print(f"Error loading activations: {e}")
        return

    if not isinstance(activations, torch.Tensor):
        print(f"Error loading activations: expected a tensor, got {type(activations).__name__}")
        return

    if seed is not None:
        torch.manual_seed(seed)

    # The SAE input width must match the last dimension of the activations (e.g. 2304).
    d_model = activations.shape[-1]
    sae = JumpReLUSAE(d_model=d_model, d_sae=d_sae, sparsity_coeff=sparsity_coeff)
    sae.to(device)

    print("Starting training...")

    # Per the original notes, optimize_on_activations expects (N, seq_len, d_model)
    # and flattens the sequence dimension into the batch.
    data_log = sae.optimize_on_activations(
        activations,
        batch_size=batch_size,
        epochs=epochs,
        log_freq=log_freq,
        lr=lr,
    )

    print("Training completed.")

    # Print final statistics from the training log.
    if data_log:
        final_stats = data_log[-1]
        print("\nFinal training statistics:")
        for key, value in final_stats.items():
            print(f"{key}: {value}")

In [2]:
# Run the whole pipeline (download activations -> build SAE -> train -> print final stats)
# with the default hyperparameters.
# NOTE(review): 12,500 steps on CPU is slow; the recorded run below was interrupted
# (KeyboardInterrupt) after ~80 steps.
main()

Using device: cpu


  activations = torch.load(activation_file)


Starting training...
Total steps: 12500 (steps per epoch: 1250)
Epoch 0/10, Step 0: lr=0.001000, recon_loss=79.583023, sparsity_loss=7741.051758, frac_active=0.380379
Epoch 0/10, Step 10: lr=0.001000, recon_loss=48.909981, sparsity_loss=344.666992, frac_active=0.021176
Epoch 0/10, Step 20: lr=0.001000, recon_loss=21.469131, sparsity_loss=441.949219, frac_active=0.027181
Epoch 0/10, Step 30: lr=0.001000, recon_loss=16.837908, sparsity_loss=467.939453, frac_active=0.028789
Epoch 0/10, Step 40: lr=0.001000, recon_loss=14.451908, sparsity_loss=501.421875, frac_active=0.030780
Epoch 0/10, Step 50: lr=0.001000, recon_loss=12.341775, sparsity_loss=560.317383, frac_active=0.034555
Epoch 0/10, Step 60: lr=0.001000, recon_loss=10.886212, sparsity_loss=610.457031, frac_active=0.037544
Epoch 0/10, Step 70: lr=0.001000, recon_loss=9.977271, sparsity_loss=668.626953, frac_active=0.041090
Epoch 0/10, Step 80: lr=0.001000, recon_loss=9.330170, sparsity_loss=703.653320, frac_active=0.043200
Epoch 0/10,

KeyboardInterrupt: 