In [1]:
import torch
from models import SparseAutoencoder  # Ensure this matches the model definition you uploaded
from huggingface_hub import hf_hub_download
import einops

# Set parameters
repo_name = "charlieoneill/sparse-coding"  # Hugging Face Hub repo id passed to hf_hub_download
model_filename = "sparse_autoencoder.pth"  # Name of the checkpoint file inside that repo
input_dim = 768  # Input width of the autoencoder; adjust to match the checkpoint being loaded
hidden_dim = 22 * input_dim  # Projection-up factor (22) * input_dim

# Download the model from Hugging Face Hub; returns a local filesystem path
model_path = hf_hub_download(repo_id=repo_name, filename=model_filename)

# Load the model
model = SparseAutoencoder(input_dim=input_dim, hidden_dim=hidden_dim)
# The checkpoint is a file fetched from the network, and torch.load unpickles it.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered file cannot execute arbitrary code. (Requires torch >= 1.13.)
state_dict = torch.load(
    model_path,
    map_location=torch.device("cpu"),  # load onto CPU regardless of where it was saved
    weights_only=True,
)
model.load_state_dict(state_dict)
model.eval()  # Set the model to evaluation mode

sparse_autoencoder.pth:   0%|          | 0.00/104M [00:00<?, ?B/s]

SparseAutoencoder(
  (encoder): Sequential(
    (0): Linear(in_features=768, out_features=16896, bias=True)
    (1): ReLU()
  )
  (decoder): Linear(in_features=16896, out_features=768, bias=True)
)

In [2]:
# Generate a sample batch of data (dummy data for this example)
batch_size = 64
seq_len = 128

# Draw the noise from a dedicated, seeded generator so this cell produces the same
# batch (and therefore the same downstream numbers) on every run, without
# altering the global RNG state that other cells may rely on.
dummy_data_seed = 0
dummy_generator = torch.Generator().manual_seed(dummy_data_seed)
dummy_batch = torch.randn(
    batch_size, seq_len, input_dim, generator=dummy_generator
)  # Shape: (batch_size, seq_len, input_dim)

# Reshape the input data as done in training: merge the batch and position axes
# so that every row is a single d_model-sized vector.
dummy_input = einops.rearrange(dummy_batch, "batch pos d_model -> (batch pos) d_model")

# Pass one batch of data through the model (no gradients needed for inference)
with torch.no_grad():
    sparse_output, reconstructed_output = model(dummy_input)

# Print the shapes of the outputs
print(f"Sparse output shape: {sparse_output.shape}")
print(f"Reconstructed output shape: {reconstructed_output.shape}")

Sparse output shape: torch.Size([8192, 16896])
Reconstructed output shape: torch.Size([8192, 768])


In [4]:
# Calculate losses
# loss_fn comes from the project's `main` module; it is called with the flattened
# input, the reconstruction and the sparse activations, and its result is
# unpacked into four values in the order shown below.
from main import loss_fn

loss_terms = loss_fn(dummy_input, reconstructed_output, sparse_output)
recon_loss, l1_loss, l0_loss, total_loss = loss_terms

# Convert the three reported terms to Python scalars before formatting.
# total_loss is kept as a variable but, as before, is not printed.
recon_value = recon_loss.item()
l1_value = l1_loss.item()
l0_value = l0_loss.item()

print(f"Reconstruction loss: {recon_value}")
print(f"L1 loss: {l1_value}")
print(f"L0 loss: {l0_value}")

Reconstruction loss: 24.001211166381836
L1 loss: 3810.80517578125
L0 loss: 8368.216796875
