In [None]:
import torch.nn as nn
import torch.optim as optim

# Define Sparse Autoencoder
class SparseAutoencoder(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(SparseAutoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(negative_slope=0.01),  # Helps sparsity
        )
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # Output values between 0-1
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded



In [4]:
import torch
import ollama

# Function to get embeddings using Ollama (with error handling)
def get_embedding(text):
    try:
        embedding_data = ollama.embeddings(model="mxbai-embed-large", prompt=text)
        embedding_vector = torch.tensor(embedding_data["embedding"], dtype=torch.float32)  # Convert to tensor
        return embedding_vector
    except Exception as e:
        print(f"❌ Unexpected error generating embedding: {e}")
    return None


In [5]:
input_dim = 1024  # Match Ollama embedding size
hidden_dim = 64  # Compressed representation


In [2]:
import numpy as np

def extract_features(embeddings, sae_model):
    """Extracts sparse features from LLM embeddings."""
    with torch.no_grad():
        features = sae_model.encoder(torch.tensor(embeddings, dtype=torch.float32))
    return features.numpy()

# Simulated token embeddings (100-dimensional)
token_embeddings = np.random.rand(10, 100)  # 10 tokens, 100 features each
features = extract_features(token_embeddings, model)
print("Extracted sparse features:", features)
print("Feature shape:", features.shape)

Extracted sparse features: [[0.17408113 0.10068728 0.0266986  0.         0.         0.
  0.         0.08642361 0.         0.        ]
 [0.1961816  0.13727127 0.         0.         0.         0.03275107
  0.00662035 0.05166993 0.         0.        ]
 [0.1979455  0.09991568 0.         0.         0.         0.00894872
  0.         0.05397782 0.         0.        ]
 [0.17810443 0.19679238 0.         0.         0.         0.
  0.         0.30340973 0.         0.        ]
 [0.06918554 0.15267411 0.03933194 0.         0.         0.
  0.         0.08321735 0.         0.02684967]
 [0.04107302 0.15781884 0.05910005 0.         0.         0.
  0.         0.06845911 0.         0.00673035]
 [0.22568682 0.18624441 0.         0.         0.         0.
  0.         0.249841   0.         0.00278136]
 [0.27213317 0.12637684 0.         0.         0.         0.00529841
  0.         0.18601379 0.         0.        ]
 [0.24132699 0.16936474 0.         0.         0.         0.04576051
  0.         0.17641509 0