In [None]:
#!/usr/bin/env python3
import torch
from jumprelu_sae import JumpReLUSAE  # Make sure jumprelu_sae.py is in your PYTHONPATH or same directory
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from huggingface_hub import HfApi


def main(
    *,
    repo_id="charlieoneill/gemma-medicine-sae",
    filename="10000_128.pt",
    d_sae=16384,
    sparsity_coeff=100.0,
    batch_size=128,
    epochs=10,
    log_freq=10,
    lr=1e-3,
):
    """Download cached activations from the Hugging Face Hub and train a JumpReLU SAE.

    Every argument is keyword-only and defaults to the value this script was
    originally hard-coded with, so calling ``main()`` behaves as before.

    Args:
        repo_id: Hugging Face Hub repository that holds the activation file.
        filename: Name of the saved activation tensor inside ``repo_id``. It is
            expected to contain a single tensor of shape
            ``(N, seq_len, d_model)``, e.g. ``[10000, 128, 2304]``.
        d_sae: Latent (dictionary) size of the SAE.
        sparsity_coeff: Forwarded to ``JumpReLUSAE(sparsity_coeff=...)``.
        batch_size: Mini-batch size forwarded to ``optimize_on_activations``.
        epochs: Number of epochs forwarded to ``optimize_on_activations``.
        log_freq: Logging interval forwarded to ``optimize_on_activations``.
        lr: Learning rate forwarded to ``optimize_on_activations``.

    Returns:
        None. Loading failures and unexpected activation shapes are printed and
        cause an early return instead of raising.
    """
    # Automatically select between CUDA and CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Download the activation file from the Hub and load it.
    try:
        activation_file = hf_hub_download(repo_id=repo_id, filename=filename)
        activations = torch.load(
            activation_file,
            # Only remap when CUDA is unavailable. A CUDA-saved tensor would
            # otherwise fail to deserialize on a CPU-only machine. With CUDA
            # present, keep torch.load's default placement.
            map_location="cpu" if device.type == "cpu" else None,
            # The file comes from the network, so refuse arbitrary pickled
            # objects. A plain tensor loads fine under this restriction, which
            # is also torch's default since 2.6.
            weights_only=True,
        )
    except Exception as e:
        # Top-level boundary of a script/notebook: report and stop rather
        # than surfacing a traceback.
        print(f"Error loading activations: {e}")
        return

    # optimize_on_activations expects (N, seq_len, d_model). Fail early with a
    # readable message instead of deep inside the training loop.
    if not isinstance(activations, torch.Tensor) or activations.ndim != 3:
        if isinstance(activations, torch.Tensor):
            found = tuple(activations.shape)
        else:
            found = type(activations).__name__
        print(f"Expected a 3-D activation tensor (N, seq_len, d_model), got {found}")
        return

    # Instantiate the SAE, taking d_model from the activations' last dimension.
    d_model = activations.shape[-1]
    sae = JumpReLUSAE(d_model=d_model, d_sae=d_sae, sparsity_coeff=sparsity_coeff)
    sae.to(device)

    print("Starting training...")

    # optimize_on_activations flattens the sequence dimension into the batch,
    # so each token is treated as a separate example.
    data_log = sae.optimize_on_activations(
        activations,
        batch_size=batch_size,
        epochs=epochs,
        log_freq=log_freq,
        lr=lr,
    )

    print("Training completed.")

    # Print the last entry of the training log, if any was returned.
    if data_log:
        final_stats = data_log[-1]
        print("\nFinal training statistics:")
        for key, value in final_stats.items():
            print(f"{key}: {value}")

In [None]:
# Run training only when executed directly (script or notebook cell), not when
# this module is imported.
if __name__ == "__main__":
    main()