In [None]:
# Text Embedding Inversion using vec2text with OpenAI Embeddings
# This notebook demonstrates inverting OpenAI's text-embedding-ada-002 embeddings back to text


import vec2text
import torch
from openai import OpenAI
import os
from getpass import getpass
import pandas as pd
import time
import math

# Constants
embedding_model = "text-embedding-ada-002"

# The OpenAI client picks up OPENAI_API_KEY from the environment. If it is not set,
# prompt for it interactively rather than hardcoding a secret in the notebook.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ")
client = OpenAI()

# Load the pre-trained corrector for the same embedding model used above, so the two
# names can never drift apart.
print(f"Loading vec2text corrector for OpenAI {embedding_model}...")
corrector = vec2text.load_pretrained_corrector(embedding_model)

In [None]:
# %%
# Diverse test corpus including sensitive information.
# Samples are written grouped by category and then flattened, in order, into one
# flat list of strings that is embedded and inverted below.
test_corpus = [
    sample
    for category in (
        # Sensitive Information (to test security implications)
        (
            "My social security number is 818-88-8118",
            "Password for admin account is SecurePass123!",
            "Credit card number 4532-1234-5678-9012 expires next month",
            "API key for production: sk-proj-1234567890abcdef",
            "Bank account routing number 021000021 checking 1234567890",
        ),
        # Medical Information
        (
            "Patient diagnosed with Type 2 diabetes, prescribed metformin",
            "Blood test results show cholesterol level of 277 mg/dL",
        ),
        # Personal Information
        (
            "John Doe lives at 1 World Trade Center, New York, NY 10007",
            "Phone number +1(212) 653-0688; email info@cyborg.co",
        ),
    )
    for sample in category
]

In [None]:
# Generate OpenAI Embeddings

def get_embeddings_openai(text_list, batch_size: int = 128) -> torch.Tensor:
    """Get embeddings from the OpenAI API, sending texts in batches for efficiency.

    Args:
        text_list: Iterable of strings to embed. It is materialized into a list
            first, so tuples and generators work as well as lists.
        batch_size: Maximum number of texts sent per API request (default 128).

    Returns:
        A float32 CPU tensor with one row per input text, in input order. For
        empty input no request is made and a ``(0, 0)`` tensor is returned, so
        the result is always 2-D and ``.shape[1]`` is always valid.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    texts = list(text_list)
    print(f"Generating embeddings for {len(texts)} texts...")
    if not texts:
        # Without a request the embedding dimension is unknown; keep the result rank-2.
        return torch.empty((0, 0), dtype=torch.float32)

    outputs = []
    for start in range(0, len(texts), batch_size):
        response = client.embeddings.create(
            input=texts[start : start + batch_size],
            model=embedding_model,
            encoding_format="float",
        )
        # Each returned item carries its position within the request; sort on it so
        # rows are guaranteed to line up with the input texts.
        ordered = sorted(response.data, key=lambda item: item.index)
        outputs.extend(item.embedding for item in ordered)
    return torch.tensor(outputs, dtype=torch.float32)

# Get embeddings as torch tensor
embeddings = get_embeddings_openai(test_corpus)
print(f"\nEmbedding shape: {embeddings.shape}")
print(f"Embedding dimension: {embeddings.shape[1]}")


def _corrector_device(corrector_obj):
    """Best-effort lookup of the device holding the corrector's model weights.

    Checks, in order: a non-None ``device`` attribute on the corrector, a non-None
    ``device`` attribute on ``corrector_obj.model``, and finally the device of the
    model's first parameter. A generic ``torch.nn.Module`` does not define
    ``.device``, but its parameters always do.

    Returns:
        The device (a ``torch.device`` or device string), or ``None`` if it cannot
        be determined.
    """
    device = getattr(corrector_obj, "device", None)
    if device is not None:
        return device
    model = getattr(corrector_obj, "model", None)
    if model is None:
        return None
    device = getattr(model, "device", None)
    if device is not None:
        return device
    if isinstance(model, torch.nn.Module):
        first_param = next(model.parameters(), None)
        if first_param is not None:
            return first_param.device
    return None


# Move to same device as the model; only guess an accelerator if the model's
# device could not be determined at all.
model_device = _corrector_device(corrector)
if model_device is not None:
    embeddings = embeddings.to(model_device)
elif torch.backends.mps.is_available():
    # On Apple Silicon
    embeddings = embeddings.to('mps')
elif torch.cuda.is_available():
    embeddings = embeddings.cuda()

print(f"Embeddings on device: {embeddings.device}")

In [None]:
# Invert embeddings
# Inversion settings: more correction steps = better quality; a wider sequence beam
# also improves quality but is slower.
NUM_STEPS = 20
SEQUENCE_BEAM_WIDTH = 1


def _reembedding_similarity(original_embedding: torch.Tensor, reconstructed_text: str) -> float:
    """Cosine similarity between an original embedding and the re-embedded reconstruction.

    The reconstruction is embedded with ``get_embeddings_openai`` and both vectors
    are compared on the CPU. Returns NaN for an empty reconstruction, because the
    OpenAI embeddings endpoint rejects empty-string inputs.
    """
    if not reconstructed_text:
        return float("nan")
    reconstructed_embedding = get_embeddings_openai([reconstructed_text])[0]
    return torch.nn.functional.cosine_similarity(
        original_embedding.cpu(), reconstructed_embedding, dim=0
    ).item()


results = []

for i, (text, embedding) in enumerate(zip(test_corpus, embeddings)):
    print(f"Embedding #{i+1}:")
    print(f"Original Text: {text}")

    # perf_counter is monotonic, so wall-clock adjustments cannot skew the duration.
    start_time = time.perf_counter()
    # invert_embeddings operates on a batch: add a batch dimension, take the only result.
    reconstructed = vec2text.invert_embeddings(
        embeddings=embedding.unsqueeze(0),
        corrector=corrector,
        num_steps=NUM_STEPS,
        sequence_beam_width=SEQUENCE_BEAM_WIDTH,
    )[0]
    inversion_time = time.perf_counter() - start_time

    print(f"Reconstructed: {reconstructed}")

    similarity = _reembedding_similarity(embedding, reconstructed)
    print(f"Similarity: {similarity:.4f}")
    print(f"Inversion time: {inversion_time:.2f}s")

    # Case- and surrounding-whitespace-insensitive comparison with the original text
    exact_match = text.lower().strip() == reconstructed.lower().strip()
    print(f"Exact match: {exact_match}")

    results.append({
        'original': text,
        'reconstructed': reconstructed,
        'similarity': similarity,
        'time': inversion_time,
        'exact_match': exact_match,
        'length_diff': abs(len(text) - len(reconstructed)),
    })

    print("-"*100)

# Summarize all inversions in one table, built once from the list of records.
results_df = pd.DataFrame(results)
if not results_df.empty:
    print(f"Exact matches: {int(results_df['exact_match'].sum())}/{len(results_df)}")
    print(f"Mean similarity: {results_df['similarity'].mean():.4f}")
results_df