# LSHRS Demo 2: PostgreSQL Integration & Reranking

## Overview

This notebook demonstrates a **production-grade workflow**:

- **Streaming Data**: Load vectors directly from PostgreSQL
- **Index Building**: Create LSH signatures and store them in Redis
- **Two-Stage Search**:
  1. Fast approximate matching using LSH
  2. Accurate re-ranking using cosine similarity
- **Performance vs Accuracy**: Observe the trade-off between speed and precision

### Architecture

```
PostgreSQL (Cold Storage)  →  Streaming Iterator
                                    ↓
                          LSHRS (Hashing)
                                    ↓
                          Redis (Hot Index)
                                    ↓
                          Query Execution
                             ├─ LSH Lookup (Fast)
                             └─ Cosine Rerank (Accurate)
```

In [None]:
import numpy as np
import pandas as pd
import time
import matplotlib.pyplot as plt
from typing import List, Tuple, Dict

# Database
import psycopg2
from sqlalchemy import create_engine, text

# LSHRS
from lshrs import LSHRS

# Configuration
# NOTE: demo-only credentials. Point DB_URL at your own instance before running.
DB_URL = "postgresql://postgres:changeme@localhost:5432/demo"
REDIS_HOST = "localhost"
REDIS_PORT = 6379
DIM = 128
NUM_PRODUCTS = 2000  # ~2K product embeddings
SEED = 42

# Build a display-only copy of the DSN with the password masked, so the
# credential is never written into saved notebook output.
from urllib.parse import urlsplit

_db_parts = urlsplit(DB_URL)
DB_URL_SAFE = (
    DB_URL.replace(f":{_db_parts.password}@", ":***@", 1)
    if _db_parts.password
    else DB_URL
)

print("✓ Imports successful")
print(f"  Database URL: {DB_URL_SAFE}")
print(f"  Vector Dimension: {DIM}")
print(f"  Target Product Count: {NUM_PRODUCTS}")

## Section 1: Database Setup & Population

Create a PostgreSQL table with product embeddings.

In [None]:
def setup_product_database():
    """Recreate the ``products`` table that backs this demo.

    Runs on a single connection opened from ``DB_URL``:

    1. Try to enable the optional ``vector`` (pgvector) extension.
    2. Drop any existing ``products`` table.
    3. Create ``products`` with a ``FLOAT[]`` embedding column.
    4. Index the ``category`` column and commit.

    Raises:
        sqlalchemy.exc.SQLAlchemyError: If connecting, dropping or creating
            the table fails. Only the optional extension step is tolerated.
    """
    # Local import keeps this cell self-contained; SQLAlchemy is already a
    # dependency of the notebook.
    from sqlalchemy.exc import SQLAlchemyError

    engine = create_engine(DB_URL)
    try:
        with engine.connect() as conn:
            # Enable pgvector extension if available. The table below stores
            # embeddings as FLOAT[] either way, so a failure here is harmless.
            try:
                conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
            except SQLAlchemyError:
                # A failed statement leaves the PostgreSQL transaction in an
                # aborted state; roll back so the DDL below can still run.
                conn.rollback()
                print("  Note: pgvector not available, using FLOAT[] arrays")

            # Drop existing table
            conn.execute(text("DROP TABLE IF EXISTS products CASCADE;"))

            # Create table with embeddings
            conn.execute(text("""
                CREATE TABLE products (
                    id SERIAL PRIMARY KEY,
                    sku VARCHAR(50) UNIQUE NOT NULL,
                    category VARCHAR(50) NOT NULL,
                    name VARCHAR(255) NOT NULL,
                    price DECIMAL(10, 2) NOT NULL,
                    embedding FLOAT[] NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                );
            """))

            # Create index on category for filtering
            conn.execute(text("CREATE INDEX idx_products_category ON products(category);"))

            conn.commit()
            print("  ✓ Table created")
    finally:
        # Release pooled connections; this engine is not reused.
        engine.dispose()


# Run setup
print("Setting up database...")
setup_product_database()

## Section 2: Generate Realistic Product Data

Create synthetic product embeddings organized by category.

In [ ]:
def populate_products():
    """Generate synthetic products with category-specific embeddings.

    Creates ``NUM_PRODUCTS`` rows, cycling through five categories. Each
    embedding is low-variance Gaussian noise of length ``DIM`` with extra
    high-variance noise added to one contiguous band of dimensions chosen by
    the category, so products of the same category share a noisy region.
    The NumPy global RNG is seeded with ``SEED`` so the data is reproducible.

    All rows are inserted into ``products`` in a single executemany call and
    committed.
    """
    engine = create_engine(DB_URL)

    categories = ['electronics', 'clothing', 'books', 'home', 'sports']
    # Width of each category's band; 25 for DIM=128 with five categories.
    # Deriving it keeps every band inside the vector if DIM changes.
    band = DIM // len(categories)
    np.random.seed(SEED)

    products = []

    print(f"Generating {NUM_PRODUCTS} products...")

    for i in range(NUM_PRODUCTS):
        # Round-robin assignment gives equally sized categories.
        cat_idx = i % len(categories)
        category = categories[cat_idx]

        # Create category-biased embedding:
        # each category gets extra variance in a different region of embedding space.
        embedding = np.random.randn(DIM).astype(np.float32) * 0.5
        lo = cat_idx * band
        embedding[lo:lo + band] += np.random.randn(band).astype(np.float32) * 2.0

        products.append({
            'sku': f'SKU-{i:06d}',
            'category': category,
            'name': f'{category.capitalize()} Product {i}',
            # Plain float: NumPy scalars are not reliably adapted by the DB
            # driver (under NumPy 2 their repr is "np.float64(...)").
            'price': float(np.random.uniform(10, 1000)),
            'embedding': embedding.tolist()
        })

    # Batch insert
    try:
        with engine.connect() as conn:
            conn.execute(
                text("INSERT INTO products (sku, category, name, price, embedding) VALUES (:sku, :category, :name, :price, :embedding)"),
                products
            )
            conn.commit()
    finally:
        # Release pooled connections; this engine is not reused.
        engine.dispose()

    print(f"  ✓ {NUM_PRODUCTS} products inserted")

populate_products()

# Verify: count the rows and distinct categories that were just inserted.
engine = create_engine(DB_URL)
with engine.connect() as conn:
    result = conn.execute(text("SELECT COUNT(*), COUNT(DISTINCT category) FROM products;")).fetchone()
    product_count, category_count = result
    print(f"  ✓ Database check: {product_count} products, {category_count} categories")

## Section 3: Define Vector Fetcher Callback

LSHRS needs a way to retrieve full vectors for re-ranking. We define a callback that fetches from Postgres.

In [None]:
def fetch_vectors_from_db(indices: List[int]) -> np.ndarray:
    """
    Fetch product embeddings from PostgreSQL.

    Args:
        indices: Product IDs to fetch. Python ints and NumPy integer scalars
            (or a NumPy array of them) are accepted.

    Returns:
        float32 np.ndarray of shape (len(indices), DIM), with rows in the same
        order as ``indices``. IDs that have no row in ``products`` yield an
        all-zero row (a warning is printed). If the lookup fails for any
        reason the error is printed and an all-zero array is returned, so a
        database hiccup degrades results instead of raising.
    """
    # Coerce to plain ints: psycopg2 cannot adapt NumPy integer scalars, and
    # `not indices` is ambiguous for NumPy arrays.
    ids = [] if indices is None else [int(idx) for idx in indices]
    if not ids:
        return np.zeros((0, DIM), dtype=np.float32)

    try:
        conn = psycopg2.connect(DB_URL)
        try:
            with conn.cursor() as cur:
                # No ORDER BY needed: rows are re-ordered via the lookup below.
                cur.execute(
                    "SELECT id, embedding FROM products WHERE id = ANY(%s);",
                    (ids,),
                )
                rows = cur.fetchall()
        finally:
            # Always release the connection, even if the query raised.
            conn.close()

        # Reconstruct in the caller's order
        lookup = {row[0]: np.array(row[1], dtype=np.float32) for row in rows}
        missing = [idx for idx in ids if idx not in lookup]
        if missing:
            print(f"WARNING: no embedding found for ids {missing[:10]}; using zero vectors")

        zero = np.zeros(DIM, dtype=np.float32)
        return np.array([lookup.get(idx, zero) for idx in ids], dtype=np.float32)
    except Exception as e:
        # Deliberate best-effort fallback: report and return zeros.
        print(f"ERROR fetching vectors: {e}")
        return np.zeros((len(ids), DIM), dtype=np.float32)

# Test the fetcher
test_vectors = fetch_vectors_from_db([1, 2, 3])
# The fetcher falls back to all-zero vectors on errors, so a non-zero result
# is the only evidence that the database round-trip actually worked.
if np.any(test_vectors):
    print("✓ Vector fetcher functional")
else:
    print("✗ Vector fetcher returned only zero vectors - check the database connection and data")
print(f"  Shape: {test_vectors.shape}")
print(f"  Sample norm: {np.linalg.norm(test_vectors[0]):.4f}")

## Section 4: Build LSH Index from PostgreSQL

Stream vectors from Postgres and create the index.

In [None]:
# Initialize LSHRS
lsh = LSHRS(
    dim=DIM,
    similarity_threshold=0.6,  # Balanced threshold
    vector_fetch_fn=fetch_vectors_from_db,  # Enable re-ranking
    redis_host=REDIS_HOST,
    redis_port=REDIS_PORT,
    redis_prefix='demo_pg',
    seed=SEED
)

# Clear before indexing so a re-run of this cell does not build on old state
# (presumably scoped to this redis_prefix - confirm against the LSHRS docs).
lsh.clear()

# Display configuration
stats = lsh.stats()
print("LSH Configuration:")
print(f"  Bands: {stats['num_bands']}")
print(f"  Rows/Band: {stats['rows_per_band']}")
print(f"  Total Bits: {stats['num_bands'] * stats['rows_per_band']}")
print(f"  Seed: {stats['seed']}")

# Stream from database
print("\nStreaming from PostgreSQL...")
# perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
start_time = time.perf_counter()

lsh.create_signatures(
    format="postgres",
    dsn=DB_URL,
    table="products",
    index_column="id",
    vector_column="embedding",
    batch_size=500,
    where_clause=None  # Index all products
)

index_time = time.perf_counter() - start_time
print(f"✓ Indexing complete in {index_time:.2f}s")
# Throughput assumes every one of the NUM_PRODUCTS rows was indexed.
print(f"  Throughput: {NUM_PRODUCTS / index_time:.0f} vectors/sec")

## Section 5: Query with Two-Stage Retrieval

Execute queries and compare LSH-only vs LSH+Reranking.

In [None]:
# Get a query product
engine = create_engine(DB_URL)
with engine.connect() as conn:
    result = conn.execute(text("SELECT id, name, category FROM products WHERE id = 1;")).fetchone()
    query_id, query_name, query_category = result

query_vec = fetch_vectors_from_db([query_id])[0]

print("Query Product:")
print(f"  ID: {query_id}")
print(f"  Name: {query_name}")
print(f"  Category: {query_category}")
print(f"  Vector Norm: {np.linalg.norm(query_vec):.4f}")

# Stage 1: LSH-only (fast, approximate)
print(f"\n{'='*60}")
print("Stage 1: LSH Candidate Retrieval (Approximate)")
print(f"{'='*60}")

# perf_counter is monotonic and high-resolution, which matters for
# millisecond-scale measurements.
t0 = time.perf_counter()
lsh_candidates = lsh.get_top_k(query_vec, topk=50)
lsh_time = (time.perf_counter() - t0) * 1000

print(f"✓ Retrieved {len(lsh_candidates)} candidates in {lsh_time:.2f}ms")
print(f"  Candidates: {lsh_candidates[:10]}...")

# Stage 2: LSH+Reranking (slower, accurate)
print(f"\n{'='*60}")
print("Stage 2: LSH + Cosine Reranking (Accurate)")
print(f"{'='*60}")

t0 = time.perf_counter()
reranked_results = lsh.get_above_p(query_vec, p=0.05)  # Top 5% similarity
rerank_time = (time.perf_counter() - t0) * 1000

print(f"✓ Retrieved {len(reranked_results)} results in {rerank_time:.2f}ms")

# Fetch product info for top results.
# IDs are coerced to plain ints so the driver can adapt them.
top_ids = [int(idx) for idx, _ in reranked_results[:10]]
# text() binds named parameters (:ids); a DB-API "%s" placeholder would never
# receive the value passed in the parameter dict.
query_str = "SELECT id, name, category FROM products WHERE id = ANY(:ids) ORDER BY id;"
product_info = {}
if top_ids:
    with engine.connect() as conn:
        result = conn.execute(text(query_str), {'ids': top_ids}).fetchall()
        product_info = {row[0]: (row[1], row[2]) for row in result}

print("\nTop 10 Results (Ranked by Cosine Similarity):")
print("Rank | ID  | Name                      | Category    | Similarity")
print("-" * 70)

for rank, (idx, score) in enumerate(reranked_results[:10], 1):
    name, cat = product_info.get(int(idx), ('N/A', 'N/A'))
    # Truncate so long names do not break the column layout.
    name = name[:23] if name else 'N/A'
    print(f"{rank:4d} | {idx:3d} | {name:25s} | {cat:11s} | {score:10.4f}")

# Comparison
print(f"\n{'='*60}")
print("Performance Comparison")
print(f"{'='*60}")
print(f"LSH Only (Top-K):      {lsh_time:7.2f}ms  ({len(lsh_candidates)} candidates)")
print(f"LSH+Rerank (Top-P):    {rerank_time:7.2f}ms  ({len(reranked_results)} results)")
print(f"Slowdown Factor:       {rerank_time/lsh_time:7.2f}x  (more accurate but slower)")
print("\nTrade-off: LSH is fast for candidate filtering, re-ranking ensures accuracy.")
print("For production: Use LSH for 50K+ candidates, then re-rank top 100.")


## Section 6: Batch Query Performance

Execute multiple queries and analyze latency distribution.

In [ ]:
import statistics

# Run 20 random product queries.
# Assumes product ids are 1..NUM_PRODUCTS (SERIAL ids on a freshly created table).
num_queries = 20
query_ids = np.random.choice(range(1, NUM_PRODUCTS + 1), size=num_queries, replace=False)

lsh_latencies = []
rerank_latencies = []

print(f"Executing {num_queries} queries...")

for qid in query_ids:
    # np.random.choice yields NumPy integers, which the DB driver cannot
    # adapt; pass a plain int so the fetcher returns the real embedding
    # instead of falling back to zeros.
    query_vec = fetch_vectors_from_db([int(qid)])[0]

    # LSH
    t0 = time.perf_counter()
    _ = lsh.get_top_k(query_vec, topk=50)
    lsh_latencies.append((time.perf_counter() - t0) * 1000)

    # LSH+Rerank
    t0 = time.perf_counter()
    _ = lsh.get_above_p(query_vec, p=0.05)
    rerank_latencies.append((time.perf_counter() - t0) * 1000)

# Statistics
def stats(data, name):
    """Summarize latency samples.

    Args:
        data: Sequence of latency measurements (milliseconds here).
        name: Label stored alongside the summary.

    Returns:
        dict with keys 'name', 'mean', 'p50', 'p95', 'p99' and 'max'.
    """
    return {
        'name': name,
        'mean': np.mean(data),
        'p50': np.percentile(data, 50),
        'p95': np.percentile(data, 95),
        'p99': np.percentile(data, 99),
        'max': np.max(data)
    }

lsh_stats = stats(lsh_latencies, 'LSH (Top-K)')
rerank_stats = stats(rerank_latencies, 'LSH+Rerank (Top-P)')

print(f"\n{'='*70}")
print(f"Latency Distribution (ms) - {num_queries} queries")
print(f"{'='*70}")
print(f"\n{'Metric':<15} {'LSH (Top-K)':<20} {'LSH+Rerank':<20}")
print("-" * 55)
for key in ['mean', 'p50', 'p95', 'p99', 'max']:
    print(f"{key.upper():<15} {lsh_stats[key]:>18.2f}ms {rerank_stats[key]:>18.2f}ms")

# Report the SLA verdict from the measured p95 instead of asserting success.
SLA_P95_MS = 100
if lsh_stats['p95'] < SLA_P95_MS and rerank_stats['p95'] < SLA_P95_MS:
    print("\n✓ Both configurations meet SLA (<100ms p95)")
else:
    print(
        f"\n✗ SLA (<{SLA_P95_MS}ms p95) not met: "
        f"LSH p95={lsh_stats['p95']:.2f}ms, LSH+Rerank p95={rerank_stats['p95']:.2f}ms"
    )


## Section 7: Filtered Indexing

Demonstrate category-specific indexing using a `WHERE` clause.

In [ ]:
# Create a separate index for "electronics" only.
# Redis host/port are passed explicitly so this index lives on the same Redis
# instance as the main one; the distinct prefix keeps the two apart.
lsh_electronics = LSHRS(
    dim=DIM,
    similarity_threshold=0.7,
    vector_fetch_fn=fetch_vectors_from_db,
    redis_host=REDIS_HOST,
    redis_port=REDIS_PORT,
    redis_prefix='demo_pg_electronics',
    seed=SEED
)

lsh_electronics.clear()

print("Building category-specific index (electronics only)...")
# perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
t0 = time.perf_counter()

lsh_electronics.create_signatures(
    format="postgres",
    dsn=DB_URL,
    table="products",
    index_column="id",
    vector_column="embedding",
    batch_size=500,
    where_clause="category = 'electronics'"  # Filter to electronics
)

elec_time = time.perf_counter() - t0

# Count electronics (reuses the `engine` created in an earlier cell)
with engine.connect() as conn:
    result = conn.execute(text("SELECT COUNT(*) FROM products WHERE category = 'electronics';")).fetchone()
    num_electronics = result[0]

print(f"✓ Indexed {num_electronics} electronics products in {elec_time:.2f}s")
print("\n💡 Use Case: Filtered indexing speeds up queries for specific categories.")
print("   This allows multi-tenant or multi-category deployments with separate indices.")

# Remove the temporary category index again.
lsh_electronics.clear()

## Section 8: Visualization

Visualize query results and latency distributions.

In [ ]:
# 2x2 dashboard: scatter and boxplot on top, one latency histogram per
# method below. Reads lsh_latencies / rerank_latencies and the
# lsh_stats / rerank_stats summaries computed in the batch-query cell.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: Latency comparison
ax = axes[0, 0]
ax.scatter(lsh_latencies, rerank_latencies, alpha=0.6, s=100, edgecolors='black')
ax.set_xlabel('LSH Latency (ms)', fontsize=10, fontweight='bold')
ax.set_ylabel('Rerank Latency (ms)', fontsize=10, fontweight='bold')
ax.set_title('Query Latency: LSH vs LSH+Rerank', fontsize=11, fontweight='bold')
ax.grid(True, alpha=0.3)
# Diagonal reference: points above it are queries where re-ranking was slower.
max_val = max(max(lsh_latencies), max(rerank_latencies))
ax.plot([0, max_val], [0, max_val], 'r--', alpha=0.3, label='1x slowdown')
ax.legend()

# Plot 2: Latency distribution boxplot
ax = axes[0, 1]
ax.boxplot([lsh_latencies, rerank_latencies])
# Label ticks after the fact: boxplot's `labels=` keyword is deprecated in
# Matplotlib 3.9, while set_xticklabels works on old and new versions.
ax.set_xticklabels(['LSH', 'LSH+Rerank'])
ax.set_ylabel('Latency (ms)', fontsize=10, fontweight='bold')
ax.set_title('Latency Distribution', fontsize=11, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

# Plots 3 and 4: latency histograms with a p95 marker, one per method
histogram_specs = (
    (axes[1, 0], lsh_latencies, lsh_stats, '#3498db', 'LSH Query Latency Distribution'),
    (axes[1, 1], rerank_latencies, rerank_stats, '#e74c3c', 'Rerank Query Latency Distribution'),
)
for ax, latencies, summary, color, title in histogram_specs:
    ax.hist(latencies, bins=10, alpha=0.7, color=color, edgecolor='black')
    ax.axvline(summary['p95'], color='red', linestyle='--', linewidth=2, label=f'p95: {summary["p95"]:.1f}ms')
    ax.set_xlabel('Latency (ms)', fontsize=10, fontweight='bold')
    ax.set_ylabel('Frequency', fontsize=10, fontweight='bold')
    ax.set_title(title, fontsize=11, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("✓ Visualization complete")


## Section 9: Cleanup


In [None]:
# Clear the main LSH index. Only the index is touched here: the PostgreSQL
# `products` table is left in place (the setup cell drops and recreates it on
# the next run), so the completion message must not claim otherwise.
lsh.clear()
print("✓ Demo 2 Complete - LSH index cleared (PostgreSQL 'products' table left in place)")