# Training Sparse Autoencoders (SAEs) for Enhanced Embedding Refinement

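This notebook trains an autoencoder to refine embeddings from a Vision Transformer (ViT). Three variants are available and are chosen with `model_type`: a Sparse Autoencoder (SAE), a plain Autoencoder (AE), or a Variational Autoencoder (VAE). The goal is to improve the separation of similar product embeddings by keeping only the most informative encoded features. The process includes data preparation, model training, variance-based feature selection, and a nearest-neighbour evaluation with visualization.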

---

## Install Required Packages

```python
%pip install torch torchvision scikit-learn tqdm faiss-cpu matplotlib

In [None]:
import os
import pickle
import json
import random
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
from glob import glob
import math
import matplotlib.pyplot as plt
import faiss
from PIL import Image

from utils.models.google_vit_model import GoogleViTModel
from utils.preprocessing import preprocess_image

In [None]:
# Choose the compute device: CUDA first, then Apple's MPS backend, then CPU.
# The MPS probe goes through torch.backends.mps, the documented availability
# check; torch.mps.is_available is missing from many PyTorch releases and would
# raise AttributeError on machines without CUDA. getattr() additionally covers
# builds that ship no MPS backend module at all.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device('cuda')
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Locate and load the cached DAM embeddings, plus the optional 3D-render cache
BASE_CACHE_DIR = '.cache'
selected_model_key = GoogleViTModel().model_name  # adjust if needed
BACKGROUND_REMOVAL_METHOD = 'RMBG_2'
EMBEDDING_AGGREGRATION_METHOD = "mean"

# Both cache files live in the same "embeddings" sub-directory
embeddings_dir = os.path.join(BASE_CACHE_DIR, "embeddings")
embeddings_file = os.path.join(embeddings_dir, f'dam_features-{selected_model_key}-{BACKGROUND_REMOVAL_METHOD}.pkl')
embeddings_file_3d = os.path.join(embeddings_dir, f'dam_features-{selected_model_key}-rembg-3d.pkl')

# NOTE: unpickling executes arbitrary code; only load cache files you produced yourself.
with open(embeddings_file, 'rb') as f:
    dam_features = pickle.load(f)

# Swap every original entry for 7 aliases named "<key>-dup-1" .. "<key>-dup-7".
# All aliases reference the same embedding object; the un-suffixed key is removed.
# NOTE(review): presumably this up-weights the original images relative to the
# 3D renders merged below - confirm the intended number of copies.
for original_key in list(dam_features):
    shared_embedding = dam_features.pop(original_key)
    for copy_no in range(1, 8):
        dam_features[f"{original_key}-dup-{copy_no}"] = shared_embedding

# The 3D cache is optional; when present its entries are merged in unduplicated
if os.path.exists(embeddings_file_3d):
    with open(embeddings_file_3d, 'rb') as f:
        dam_features_3d = pickle.load(f)
    dam_features.update(dam_features_3d)

# Function to aggregate embeddings if needed
def aggregate_embedding(embedding):
    """Collapse an embedding to a single vector.

    Size-1 axes are dropped first. If at least two axes remain, the result is
    the mean over the first remaining axis (the token axis for a
    ``(1, tokens, dim)`` input). Otherwise the values are returned flattened.

    Args:
        embedding: Array-like exposing ``shape``, ``squeeze``, ``mean`` and
            ``flatten`` (e.g. a NumPy array).

    Returns:
        The pooled embedding; 1-D whenever at most two non-singleton axes are
        present in the input.
    """
    if len(embedding.shape) >= 2:
        embedding = embedding.squeeze()

    if len(embedding.shape) >= 2:
        # Several rows remain: average them into one vector
        return embedding.mean(axis=0)

    # Already a single vector (e.g. a (1, dim) input after squeezing) or a
    # scalar: averaging over axis 0 here would wrongly collapse the feature
    # axis to one number, or fail outright on a 0-d value.
    return embedding.flatten()

# Replace every stored embedding, in place, by its aggregated single-vector form.
# Iterating over a snapshot keeps the loop independent of the dict being rewritten.
for path_key, raw_embedding in list(dam_features.items()):
    dam_features[path_key] = aggregate_embedding(raw_embedding)

print(f"Loaded {len(dam_features)} DAM embeddings")

In [None]:
class SparseAutoencoder(nn.Module):
    """MLP autoencoder with an L1 activity penalty on its code.

    The encoder maps ``input_dim -> hidden_dim -> hidden_dim // 2`` with a ReLU
    after each layer; the decoder mirrors it back to ``input_dim``.

    Args:
        input_dim: Length of each input vector.
        hidden_dim: Width of the first encoder layer; the code has
            ``hidden_dim // 2`` units.
        sparsity_lambda: Weight of the mean absolute code activation returned
            by :meth:`sparsity_loss`.
        bounded_output: If True (default, the historical behaviour) the
            reconstruction goes through a sigmoid and is confined to [0, 1].
            If False the last decoder layer is left linear.
    """

    def __init__(self, input_dim, hidden_dim=512, sparsity_lambda=1e-3, bounded_output=True):
        super().__init__()
        code_dim = hidden_dim // 2
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, code_dim),
            nn.ReLU(),
        )
        # NOTE(review): a sigmoid can only reproduce targets inside [0, 1].
        # Pooled ViT features presumably contain negative components - confirm,
        # and pass bounded_output=False if so. nn.Identity keeps the layer
        # positions (and therefore state_dict keys) the same in both modes.
        self.decoder = nn.Sequential(
            nn.Linear(code_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid() if bounded_output else nn.Identity(),
        )
        self.sparsity_lambda = sparsity_lambda

    def forward(self, x):
        """Return ``(reconstruction, code)`` for a batch ``x``."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded, encoded

    def sparsity_loss(self, encoded):
        """Return ``sparsity_lambda`` times the mean absolute value of ``encoded``."""
        return self.sparsity_lambda * torch.mean(torch.abs(encoded))

In [None]:
class Autoencoder(nn.Module):
    """Plain MLP autoencoder.

    The encoder maps ``input_dim -> hidden_dim -> hidden_dim // 2`` with a ReLU
    after each layer; the decoder mirrors it back to ``input_dim``.

    Args:
        input_dim: Length of each input vector.
        hidden_dim: Width of the first encoder layer; the code has
            ``hidden_dim // 2`` units.
        bounded_output: If True (default, the historical behaviour) the
            reconstruction goes through a sigmoid and is confined to [0, 1].
            If False the last decoder layer is left linear.
    """

    def __init__(self, input_dim, hidden_dim=512, bounded_output=True):
        super().__init__()
        code_dim = hidden_dim // 2
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, code_dim),
            nn.ReLU(),
        )
        # NOTE(review): a sigmoid can only reproduce targets inside [0, 1].
        # Pooled ViT features presumably contain negative components - confirm,
        # and pass bounded_output=False if so. nn.Identity keeps the layer
        # positions (and therefore state_dict keys) the same in both modes.
        self.decoder = nn.Sequential(
            nn.Linear(code_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid() if bounded_output else nn.Identity(),
        )

    def forward(self, x):
        """Return ``(reconstruction, code)`` for a batch ``x``."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded, encoded

In [None]:
class VariationalAutoencoder(nn.Module):
    """MLP variational autoencoder with a diagonal Gaussian latent.

    The encoder maps ``input_dim -> hidden_dim -> hidden_dim // 2``; two linear
    heads then produce the latent mean and log-variance. The decoder maps
    ``latent_dim -> hidden_dim // 2 -> hidden_dim -> input_dim``.

    Args:
        input_dim: Length of each input vector.
        hidden_dim: Width of the first encoder layer.
        latent_dim: Size of the latent vector.
        bounded_output: If True (default, the historical behaviour) the
            reconstruction goes through a sigmoid and is confined to [0, 1].
            If False the last decoder layer is left linear.
    """

    def __init__(self, input_dim, hidden_dim=512, latent_dim=256, bounded_output=True):
        super().__init__()
        half_dim = hidden_dim // 2

        # Encoder trunk shared by both latent heads
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, half_dim),
            nn.ReLU(),
        )
        # Latent heads: mean and log-variance
        self.fc_mu = nn.Linear(half_dim, latent_dim)
        self.fc_logvar = nn.Linear(half_dim, latent_dim)

        # NOTE(review): a sigmoid can only reproduce targets inside [0, 1].
        # Pooled ViT features presumably contain negative components - confirm,
        # and pass bounded_output=False if so. nn.Identity keeps the layer
        # positions (and therefore state_dict keys) the same in both modes.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, half_dim),
            nn.ReLU(),
            nn.Linear(half_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid() if bounded_output else nn.Identity(),
        )

    def encode(self, x):
        """Return the latent ``(mu, logvar)`` for a batch ``x``."""
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps`` standard normal."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)  # Sample from standard normal
        return mu + eps * std

    def decode(self, z):
        """Map a latent batch ``z`` back to input space."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)``.

        A latent sample is drawn on every call, regardless of train/eval mode.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        reconstructed = self.decode(z)
        return reconstructed, mu, logvar

In [None]:
class EmbeddingDataset(Dataset):
    """Dataset serving one flattened float32 embedding per item.

    ``paths`` holds the keys of the source mapping in iteration order and
    ``embeddings`` the matching rows as a single float32 NumPy array.
    """

    def __init__(self, dam_features):
        """
        dam_features: dict mapping image paths to aggregated embeddings.
        """
        keys, rows = [], []
        for key, vector in dam_features.items():
            keys.append(key)
            rows.append(vector.flatten())
        self.paths = keys
        self.embeddings = np.array(rows, dtype=np.float32)

    def __len__(self):
        """Number of stored embeddings."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return row ``idx`` as a new float32 tensor."""
        return torch.tensor(self.embeddings[idx], dtype=torch.float32)

In [None]:
# The network input width is the flattened length of any one stored embedding
sample_embedding = next(iter(dam_features.values()))
input_dim = sample_embedding.flatten().shape[0]
print(f"Input dimension: {input_dim}")

# Model Selection
model_type = 'AE'  # Options: 'SAE', 'AE', 'VAE'

# Hyper-parameters (you can adjust these as needed)
hidden_dim = 256
latent_dim = 256  # Only used for VAE
sparsity_lambda = 1e-3  # Only used for SAE

# One zero-argument builder per supported model; only the selected one is run
model_builders = {
    'SAE': lambda: SparseAutoencoder(input_dim=input_dim, hidden_dim=hidden_dim, sparsity_lambda=sparsity_lambda),
    'AE': lambda: Autoencoder(input_dim=input_dim, hidden_dim=hidden_dim),
    'VAE': lambda: VariationalAutoencoder(input_dim=input_dim, hidden_dim=hidden_dim, latent_dim=latent_dim),
}
if model_type not in model_builders:
    raise ValueError(f"Unsupported model_type: {model_type}")
autoencoder = model_builders[model_type]().to(device)

print(f"Selected model: {model_type}")

# Reconstruction objective and optimizer
reconstruction_loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=1e-3)

# Shuffled mini-batches over every stored embedding
embedding_dataset = EmbeddingDataset(dam_features)
embedding_dataloader = DataLoader(embedding_dataset, shuffle=True, batch_size=64)

In [None]:
num_epochs = 50  # Adjust as needed
epoch_losses = []  # mean batch loss of each completed epoch

autoencoder.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    for batch in tqdm(embedding_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        batch = batch.to(device)
        optimizer.zero_grad()

        # Every model returns the reconstruction first; what follows depends on
        # the model: (mu, logvar) for the VAE, the code for SAE and AE.
        reconstructed, *extras = autoencoder(batch)
        loss = reconstruction_loss_fn(reconstructed, batch)

        if model_type == 'VAE':
            # Add the KL divergence to a standard normal, averaged over all elements
            mu, logvar = extras
            loss = loss + (-0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()))
        elif model_type == 'SAE':
            # Add the L1 activity penalty on the code
            (encoded,) = extras
            loss = loss + autoencoder.sparsity_loss(encoded)
        # Regular AE: reconstruction loss only

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(embedding_dataloader)
    epoch_losses.append(avg_loss)
    print(f"Epoch [{epoch+1}/{num_epochs}] - Loss: {avg_loss:.6f}")

In [None]:
# Training-loss curve, one marker per recorded epoch
loss_fig, loss_ax = plt.subplots(figsize=(8, 5))
plotted_epochs = range(1, min(num_epochs, len(epoch_losses)) + 1)
loss_ax.plot(plotted_epochs, epoch_losses, marker='o', label='Training Loss')
loss_ax.set_title('Autoencoder Training Loss Over Epochs')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
# Ticks cover the full configured epoch range, not only the recorded ones
loss_ax.set_xticks(range(1, num_epochs+1))
loss_ax.grid(True)
loss_ax.legend()
plt.show()

In [None]:
# Run every stored embedding through the trained model and keep its latent vector
autoencoder.eval()
encoded_features = {}

with torch.no_grad():
    for path, emb in tqdm(dam_features.items(), desc="Encoding embeddings"):
        emb_tensor = torch.tensor(emb.flatten(), dtype=torch.float32).unsqueeze(0).to(device)

        # The second output is the latent representation for every model in
        # this notebook: the mean `mu` for the VAE, the code for SAE and AE.
        outputs = autoencoder(emb_tensor)
        latent = outputs[1]

        # Drop the batch axis and move to NumPy for the analysis steps
        encoded_features[path] = latent.cpu().squeeze(0).numpy()

# One row per embedding, in dictionary order
encoded_matrix = np.array(list(encoded_features.values()))
print(f"Encoded features shape: {encoded_matrix.shape}")

In [None]:
from sklearn.feature_selection import VarianceThreshold

if model_type == 'VAE':
    # For VAEs, skip feature selection and use all encoded features
    print("Feature selection is disabled for VAEs. Using all encoded features.")
    encoded_matrix_selected = encoded_matrix
else:
    # Drop latent dimensions whose variance across the dataset is at most the
    # threshold; `selector` is reused later to transform new encodings.
    selector = VarianceThreshold(threshold=0.01)  # Adjust threshold as needed
    encoded_matrix_selected = selector.fit(encoded_matrix).transform(encoded_matrix)
    selected_features = selector.get_support(indices=True)

    print(f"Selected {len(selected_features)} out of {encoded_matrix.shape[1]} features based on variance threshold.")
    print(f"Shape after feature selection: {encoded_matrix_selected.shape}")

# Map each path to its (possibly reduced) row; rows follow encoded_features order
selected_encoded_features = dict(zip(encoded_features, encoded_matrix_selected))

In [None]:
# Side-by-side histograms of latent values before and after feature selection
plt.figure(figsize=(12, 5))

histogram_panels = (
    (encoded_matrix, 'blue', 'Encoded Features Distribution (All)'),
    (encoded_matrix_selected, 'green', 'Encoded Features Distribution (Selected)'),
)
for position, (matrix, colour, heading) in enumerate(histogram_panels, start=1):
    plt.subplot(1, 2, position)
    plt.hist(matrix.flatten(), bins=50, alpha=0.7, color=colour)
    plt.title(heading)
    plt.xlabel('Feature Value')
    plt.ylabel('Frequency')

plt.tight_layout()
plt.show()

In [None]:
# Image directories and the file patterns accepted for test images
dam_dir = 'data/DAM'
test_dir = 'data/test_image_headmind'
extensions = ['*.jpg', '*.jpeg', '*.png']

# DAM images are collected as *.jpeg only; test images may use any listed extension
dam_images = sorted(glob(os.path.join(dam_dir, '*.jpeg')))
test_images = sorted(
    match
    for ext in extensions
    for match in glob(os.path.join(test_dir, ext))
)

# One-column frames holding the image paths
dam_df = pd.DataFrame({'image_path': dam_images})
test_df = pd.DataFrame({'image_path': test_images})

# Ground truth: image name -> list of accepted reference codes.
# The 'reference' cell may list several codes separated by '/'; blank parts
# and the placeholder '?' are dropped.
labels_df = pd.read_csv('labels/handmade_test_labels.csv')
labels_dict = {}
for image_cell, reference_cell in zip(labels_df['image'], labels_df['reference']):
    stripped_parts = (part.strip() for part in str(reference_cell).split('/'))
    labels_dict[image_cell.strip()] = [code for code in stripped_parts if code and code != '?']

In [None]:
# Latent vector (after feature selection, where enabled) for every DAM embedding
encoded_dam_features = {}
apply_selection = model_type != 'VAE'  # the VAE path keeps every latent dimension

with torch.no_grad():
    for path, stored in tqdm(dam_features.items(), desc="Encoding DAM embeddings"):
        emb_tensor = torch.tensor(stored.flatten(), dtype=torch.float32).unsqueeze(0).to(device)

        # Second model output: `mu` for the VAE, the code for SAE and AE
        latent = autoencoder(emb_tensor)[1]
        encoded_selected = latent.cpu().squeeze(0).numpy()

        if apply_selection:
            # The fitted selector expects a 2-D (samples, features) array
            encoded_selected = selector.transform(encoded_selected.reshape(1, -1)).squeeze(0)

        encoded_dam_features[path] = encoded_selected

In [None]:
# Initialize VIT model for extracting test embeddings
vit_model = GoogleViTModel()
vit_model_device = device  # Use same device

# Latent vector (after feature selection, where enabled) for each test image
# that could be preprocessed; images for which preprocessing returns None are skipped.
encoded_test_features = {}

for test_path in tqdm(test_df['image_path'], total=len(test_df), desc="Encoding Test embeddings"):
    # Use the same background-removal setting the DAM cache file was selected
    # with, so both sides of the comparison are preprocessed consistently.
    test_obj = preprocess_image(test_path, background_removal=BACKGROUND_REMOVAL_METHOD)
    if test_obj is None:
        continue

    # Inference only: keep the ViT and autoencoder forward passes out of autograd
    with torch.no_grad():
        # Extract VIT embedding and aggregate
        test_feat = vit_model.extract_features(test_obj)
        test_emb = aggregate_embedding(test_feat).flatten()
        test_tensor = torch.tensor(test_emb, dtype=torch.float32).unsqueeze(0).to(device)

        if model_type == 'VAE':
            reconstructed, mu, logvar = autoencoder(test_tensor)
            encoded_np = mu.cpu().squeeze(0).numpy()  # the mean is used as the encoding
        else:
            reconstructed, encoded = autoencoder(test_tensor)
            encoded_np = encoded.cpu().squeeze(0).numpy()

    if model_type != 'VAE':
        # Apply the feature selection fitted on the DAM encodings
        encoded_selected = selector.transform(encoded_np.reshape(1, -1)).squeeze(0)
    else:
        # Use all features
        encoded_selected = encoded_np

    encoded_test_features[test_path] = encoded_selected

In [None]:
# Stack the encoded DAM vectors; row i of the matrix belongs to dam_paths[i]
dam_paths = list(encoded_dam_features)
dam_encoded_matrix = np.array(list(encoded_dam_features.values()), dtype=np.float32)

# Exact (brute-force) nearest-neighbour index using Euclidean distance
dim = dam_encoded_matrix.shape[1]
index = faiss.IndexFlatL2(dim)
index.add(dam_encoded_matrix)
print(f"FAISS index built with {index.ntotal} vectors, dimension={dim}.")

In [None]:
def find_best_match(test_feature, dam_features, index, dam_paths, top_n=1):
    """
    Find the best matching DAM images for a given test feature.

    Args:
        test_feature: 1-D feature vector of the query image.
        dam_features: Unused; kept so existing callers keep working.
        index: Object exposing ``search(queries, k) -> (distances, indices)``,
            such as a FAISS index.
        dam_paths: Sequence in which position ``i`` is the path of the i-th
            indexed vector.
        top_n: Number of neighbours to request.

    Returns:
        List of ``(path, distance)`` tuples in the order returned by the index.
        It holds fewer than ``top_n`` entries when the index reports fewer hits.
    """
    query = np.expand_dims(test_feature, axis=0).astype(np.float32)
    distances, indices = index.search(query, top_n)
    # FAISS fills unavailable result slots with label -1 (e.g. when top_n
    # exceeds the number of indexed vectors). Used as a list index, -1 would
    # silently return the LAST path, so such slots are skipped.
    return [
        (dam_paths[idx], dist)
        for idx, dist in zip(indices[0], distances[0])
        if idx >= 0
    ]

In [None]:
def _show_image_or_message(ax, image_path):
    """Draw the image at ``image_path`` on ``ax``, or a centred note if that fails."""
    try:
        ax.imshow(Image.open(image_path))
    except Exception:
        # Best effort: a missing or unreadable image must not abort the
        # evaluation. Unlike a bare `except:`, this lets KeyboardInterrupt
        # through so a long run can still be cancelled.
        ax.text(0.5, 0.5, "Image not found",
                horizontalalignment='center', verticalalignment='center', fontsize=12)


# Top-1 retrieval evaluation: for every encoded test image, look up its nearest
# DAM vector, compare the predicted product code with the labelled references,
# and show test / predicted / true images side by side.
correct_top1 = 0
total = 0

for test_path, test_feat in tqdm(encoded_test_features.items(), desc="Evaluating"):
    # Strip any "-dup..." suffix so the paths point at real image files
    test_path = test_path.split("-dup")[0]
    top_matches = find_best_match(test_feat, encoded_dam_features, index, dam_paths, top_n=1)
    predicted_path, distance = top_matches[0]
    predicted_path = predicted_path.split("-dup")[0]
    # Product code = file name up to the first '.' and then up to the first '-'
    predicted_code = os.path.basename(predicted_path).split('.')[0].split('-')[0]

    # Get true references for current test image (none when it is unlabelled)
    test_image_name = os.path.basename(test_path)
    true_references = labels_dict.get(test_image_name, [])

    if predicted_code in true_references:
        correct_top1 += 1
    total += 1

    # Plot test image, predicted and true match images side by side
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    _show_image_or_message(axes[0], test_path)
    axes[0].set_title("Test Image")
    axes[0].axis('off')

    _show_image_or_message(axes[1], predicted_path)
    axes[1].set_title(f"Predicted Match\n{os.path.basename(predicted_path)}")
    axes[1].axis('off')

    # First DAM entry whose file name starts with a true reference, trying the
    # references in label order
    true_match_path = next(
        (
            path
            for ref in true_references
            for path in dam_features
            if os.path.basename(path).startswith(ref)
        ),
        None,
    )
    true_match_path = true_match_path.split("-dup")[0] if true_match_path else None

    if true_match_path:
        _show_image_or_message(axes[2], true_match_path)
    else:
        axes[2].text(0.5, 0.5, "True match not found",
                     horizontalalignment='center', verticalalignment='center', fontsize=12)
    axes[2].set_title("True Match")
    axes[2].axis('off')

    plt.show()
    # One figure is created per test image; release it so they do not pile up
    plt.close(fig)

accuracy_top1 = correct_top1 / total if total > 0 else 0
print(f"Top-1 Accuracy on test set: {accuracy_top1:.2%} ({correct_top1}/{total})")

In [None]:
# Save trained Autoencoder model (weights only, as a state_dict)
model_save_path = f".cache/models/{model_type.lower()}_autoencoder.pth"
# Create the target directory first; saving into a missing directory fails
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
torch.save(autoencoder.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

# Save benchmark results
benchmark = {
    "model_type": model_type,
    "top_1_accuracy": accuracy_top1
}
benchmark_save_path = os.path.join("benchmarks", f"{model_type.lower()}_benchmark.json")
# Same reason as above: a fresh checkout may not contain "benchmarks/" yet
os.makedirs(os.path.dirname(benchmark_save_path), exist_ok=True)
with open(benchmark_save_path, 'w') as f:
    json.dump(benchmark, f)
print(f"Benchmark results saved to {benchmark_save_path}")