# Vector Database Operations with pgvector

In [2]:
import os
import psycopg2
from docarray import BaseDoc, DocList
from docarray.typing import NdArray
from langchain_openai import OpenAIEmbeddings
from dotenv import load_dotenv
from pgvector.psycopg2 import register_vector


# Load environment variables from the .env file so DATABASE_URL is available below
load_dotenv()

# Database connection.
# os.environ[...] raises KeyError if DATABASE_URL is not set, so a missing
# setting fails here rather than inside psycopg2.
DATABASE_URL = os.environ["DATABASE_URL"]
conn = psycopg2.connect(DATABASE_URL)
# Register pgvector's vector type on this connection (pgvector.psycopg2 helper)
register_vector(conn)
# Single cursor shared by every later cell in this notebook
cur = conn.cursor()

In [None]:
# Roll back the current transaction on `conn`. Meant to be run by hand after a
# statement in another cell has failed, so the connection can be used again.
conn.rollback()

In [3]:
# Fetch one row per drink whose reference starts with 'http':
# (drink_id, drink_name, tags, ingredients), where tags and ingredients are the
# distinct names joined into a single comma-separated string, sorted by drink name.
drinks_sql = """
SELECT 
  d.id AS drink_id,
  d.name AS drink_name,
  STRING_AGG(DISTINCT t.name, ', ') AS tags,
  STRING_AGG(DISTINCT i.name, ', ') AS ingredients
FROM drinks d
LEFT JOIN drink_tags dt ON d.id = dt.drink_id
LEFT JOIN tags t ON dt.tag_id = t.id
LEFT JOIN drink_ingredients di ON d.id = di.drink_id
LEFT JOIN ingredients i ON di.ingredient_id = i.id
WHERE d.reference like 'http%'
GROUP BY d.id, d.name
ORDER BY d.name
"""
cur.execute(drinks_sql)
drinks_query_res = cur.fetchall()

In [4]:
# Sanity check: print the first five rows returned by the drinks query
for row in drinks_query_res[:5]:
    print(row)

(315, '(Beware the) Pink Slip', 'aromatic, herbal, spicy, zesty', 'absinthe rouge, dry vermouth, gin, ginger liqueur, honey, lime')
(3185, '(Twice) Improved Whiskey Sour', 'aromatic, rich, smooth, tart', 'bitters, egg white, honey syrup, lemon juice, lemon peel, lime juice, maraschino liqueur, rye')
(1433, '(the) Hinges', 'aromatic, bitter, bold, herbal', 'campari, gin, herbal liqueur, irish whiskey, orange bitters, sweet vermouth')
(18, '100-Year-Old Cigar', 'aromatic, bitter, rich, smoky', 'absinthe, añejo rum, bitters, bénédictine, cynar, islay scotch')
(15, '15 Second Punch', 'bitter, floral, vibrant, zesty', 'campari, elderflower liqueur, gin, grapefruit juice, lemon juice')


In [11]:
# Create a list of drink objects
def _drink_record(row):
    """Convert one query row into the dict used for embedding and storage.

    Args:
        row: A (drink_id, drink_name, tags, ingredients) tuple as returned by
            the drinks query.

    Returns:
        A dict with keys drink_id, drink_name, ingredients, tags and
        drink_description (the sentence that gets embedded).
    """
    drink_id, name, tags, ingredients = row

    # STRING_AGG over a LEFT JOIN yields NULL (None) when a drink has no tags or
    # no ingredients. Coalesce to "" so the `str` fields of the document stay
    # valid and the literal text "None" never ends up in the embedded sentence.
    tags = tags or ""
    ingredients = ingredients or ""

    # Same wording as before for fully populated rows; empty parts are skipped.
    description = f"{name}: a {tags} cocktail" if tags else f"{name}: a cocktail"
    if ingredients:
        description += f" made with {ingredients}"

    return {
        "drink_id": int(drink_id),
        "drink_name": name,
        "ingredients": ingredients,
        "tags": tags,
        "drink_description": description,
    }


drinks = [_drink_record(row) for row in drinks_query_res]

In [None]:
# Embedding client shared by the indexing cell and find_similar_drinks.
# No model is passed, so the library default is used.
# NOTE(review): confirm that default produces 1536-dimensional vectors, which is
# what DrinkDoc (NdArray[1536]) and the vector(1536) column both expect.
embeddings = OpenAIEmbeddings()


# Define vector db schema
class DrinkDoc(BaseDoc):
    """One drink together with the embedding of its description.

    The fields correspond to the non-key columns of the `embeddings` table
    created later in this notebook (drink_id, drink_name, ingredients, tags,
    embedding).
    """

    drink_id: int  # id of the row in the `drinks` table
    drink_name: str
    ingredients: str  # comma-separated ingredient names from the drinks query
    tags: str  # comma-separated tag names from the drinks query
    embedding: NdArray[1536]  # same dimension as the vector(1536) column


# Embed description and create documents.
# All descriptions go through one batched embed_documents call instead of one
# embed_query request per drink; the vectors come back in input order.
drink_vectors = embeddings.embed_documents(
    [drink["drink_description"] for drink in drinks]
)

# Guard the pairing below: zip() would silently drop drinks on a short result.
assert len(drink_vectors) == len(drinks), "expected one embedding per drink"

docs = DocList[DrinkDoc](
    DrinkDoc(
        drink_id=drink["drink_id"],
        drink_name=drink["drink_name"],
        ingredients=drink["ingredients"],
        tags=drink["tags"],
        embedding=vector,
    )
    for drink, vector in zip(drinks, drink_vectors)
)

In [None]:
# Create vector table in database.
# IF NOT EXISTS keeps this cell re-runnable: a plain CREATE TABLE raises on the
# second run because the table is already there.
create_table_command = """
CREATE TABLE IF NOT EXISTS embeddings (
    id bigserial primary key,
    drink_id integer unique not null,
    drink_name text unique not null,
    ingredients text,
    tags text,
    embedding vector(1536)
);
"""

try:
    # Execute the SQL command and make the new table visible to other sessions
    cur.execute(create_table_command)
    conn.commit()
except Exception:
    # A failed statement leaves the transaction aborted; roll back so the shared
    # connection stays usable in later cells, then surface the original error.
    conn.rollback()
    raise

In [None]:
# Store documents with embeddings in the database.
# On a drink_id conflict every derived column is refreshed, not just the vector:
# the embedding is computed from the name, tags and ingredients, so updating it
# alone would leave the stored metadata out of sync with the vector.
upsert_sql = """
    INSERT INTO embeddings (drink_id, drink_name, ingredients, tags, embedding)
    VALUES (%s, %s, %s, %s, %s)
    ON CONFLICT (drink_id) DO UPDATE
    SET drink_name = EXCLUDED.drink_name,
        ingredients = EXCLUDED.ingredients,
        tags = EXCLUDED.tags,
        embedding = EXCLUDED.embedding
"""

try:
    for doc in docs:
        cur.execute(
            upsert_sql,
            (doc.drink_id, doc.drink_name, doc.ingredients, doc.tags, doc.embedding),
        )
    # Single commit: either every document is stored or none is
    conn.commit()
except Exception:
    # Undo the partial batch and clear the aborted transaction so the shared
    # connection remains usable, then re-raise the original error.
    conn.rollback()
    raise

In [32]:
# Example similarity search
def find_similar_drinks(query_text: str, limit: int = 10):
    """Return the stored drinks closest to ``query_text``.

    The text is embedded with the notebook-level ``embeddings`` client and
    compared against the ``embeddings`` table through the shared cursor ``cur``.
    Distance is pgvector's ``<=>`` operator (cosine distance); the other
    operators noted for this notebook are ``<->`` (L2 distance) and ``<+>``
    (L1 distance).

    Args:
        query_text: Free text describing the drink being looked for.
        limit: Maximum number of rows to return.

    Returns:
        The result of ``cur.fetchall()``: (drink name, distance) rows ordered
        by ascending distance.
    """
    nearest_drinks_sql = """
        SELECT d.name, e.embedding <=> %s::vector as distance
        FROM embeddings e
        JOIN drinks d ON d.id = e.drink_id
        ORDER BY distance
        LIMIT %s
    """
    params = (embeddings.embed_query(query_text), limit)

    cur.execute(nearest_drinks_sql, params)
    return cur.fetchall()

In [37]:
# Try the search with a free-text request that mixes ingredients and mood
example_query = "Ingredients: sweet, whisky, herbal Mood: celebratory, party"
similar_drinks = find_similar_drinks(example_query, limit=3)
print("Similar drinks:", similar_drinks)

Similar drinks: [('Holidays Away', 0.13440680503845215), ('Campbeltown', 0.14333091457713065), ('Traveling Scotsman', 0.14390814304351807)]


In [100]:
# Clean up: close the shared cursor, then the connection it was created from.
# Cells that use `cur` or `conn` need the connection cell re-run after this.
cur.close()
conn.close()