In [1]:
# Re-import edited local modules (e.g. utils/cosine_dist.py, written by a
# later cell) automatically before code is executed.
%load_ext autoreload
%autoreload 2

In [2]:
# One-time setup: start Postgres with pgvector in Docker. Host port 2345 maps to
# the container's 5432 and matches the connection URL used below.
# NOTE(review): assumes a locally built image named `postgres-with-pgvector`.
#!docker run -d --name pgvector-c -e POSTGRES_PASSWORD=mysecretpassword -p 2345:5432 postgres-with-pgvector

In [3]:
%%writefile utils/cosine_dist.py

# Cosine distance is the simplest operation!

import numpy as np

def cos_dist(e1, e2):
    """Cosine distance between two vectors: 1 - cos(angle between e1 and e2).

    Gives 0 for parallel vectors, 1 for orthogonal ones and 2 for opposite
    ones. A zero-length input vector yields nan (0 / 0).
    """
    dot_product = np.dot(e1, e2)
    squared_norms = np.dot(e1, e1) * np.dot(e2, e2)
    return 1 - dot_product / np.sqrt(squared_norms)

Overwriting utils/cosine_dist.py


In [4]:
from utils.cosine_dist import cos_dist

In [5]:
EMBEDDINGS_LEN = 1536  # OpenAI embedding size

NUM_DOCS = 2*10**5

print(f'Number of documents: {NUM_DOCS:,}')
# pgvector stores each dimension as a 4-byte float; per-row headers and the
# other columns come on top of this, hence "more than".
print(f'More than {4*NUM_DOCS*EMBEDDINGS_LEN / 10**9:.2f} GB (size of embeddings only) will be stored in the table')


Number of documents: 200,000
More than  1.2 GB (size of embeddings only) will be stored in the table


In [6]:
from pgvector.sqlalchemy import Vector
from sqlalchemy import create_engine, insert, select, text, Integer, String, Text
from sqlalchemy.orm import declarative_base, mapped_column, Session

import os

# Connection URL is taken from PGVECTOR_DATABASE_URL when set, so real
# credentials never need to be written into the notebook. The fallback only
# matches the local dev container from the docker command above (port 2345).
DATABASE_URL = os.environ.get(
    'PGVECTOR_DATABASE_URL',
    'postgresql+psycopg://postgres:mysecretpassword@localhost:2345/postgres',
)
engine = create_engine(DATABASE_URL)

# The pgvector extension must be enabled in the database before any
# `vector` column can be created; engine.begin() commits on exit.
with engine.begin() as conn:
    conn.execute(text('CREATE EXTENSION IF NOT EXISTS vector'))

Base = declarative_base()


class Document(Base):
    """ORM mapping for the `document` table: an id, free text and its embedding."""
    __tablename__ = 'document'
    
    id = mapped_column(Integer, primary_key=True)
    content = mapped_column(Text)
    # pgvector column with a fixed dimensionality of EMBEDDINGS_LEN.
    embedding = mapped_column(Vector(EMBEDDINGS_LEN))


# Start from a fresh, empty table on every run: drop_all removes the existing
# `document` table together with all of its rows before it is recreated.
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)
# Single ORM session reused by all following cells.
session = Session(engine)


In [7]:
# Sanity check: the freshly created table should contain no rows yet.
session.execute(text("SELECT count(*) from document")).all()

[(0,)]

In [8]:
# Will insert in batches
# Rows per INSERT batch; each batch is also one synthetic cluster.
BATCH_SIZE = 1_000


In [10]:
import numpy as np

# Seed NumPy's global RNG so the synthetic data (these centers and the noise
# drawn in the next cell) is identical on every Restart & Run All.
SEED = 42
np.random.seed(SEED)

# One random center per batch; batch i of BATCH_SIZE documents is generated
# around center i, so the document count must split evenly into batches.
assert NUM_DOCS % BATCH_SIZE == 0, 'NUM_DOCS must be a multiple of BATCH_SIZE'
centers = np.random.rand(NUM_DOCS // BATCH_SIZE, EMBEDDINGS_LEN)
centers.shape

(200, 1536)

In [11]:
# Gaussian noise with std 1/20 = 0.05. The same matrix is added to every
# center in the insert loop, so all batches share identical per-row offsets.
error = np.random.standard_normal(size=(BATCH_SIZE, EMBEDDINGS_LEN)) / 20
error.shape


(1000, 1536)

In [15]:
def lines_to_stringio(lines):
    """Concatenate an iterable of strings into an in-memory text buffer.

    Useful as the file-like source of a ``COPY ... FROM STDIN`` bulk load.
    Items are written verbatim (no separator is added), so each one should
    already end with its newline.

    Returns:
        io.StringIO rewound to position 0, ready to be read.
    """
    import io  # `io` is not imported anywhere else in this notebook

    buffer = io.StringIO()
    buffer.writelines(lines)
    buffer.seek(0)
    return buffer

# Row count through a raw DBAPI connection (bypassing the ORM session).
conn = engine.raw_connection()
try:
    with conn.cursor() as cursor:
        cursor.execute('SELECT COUNT(*) FROM document;')
        print(cursor.fetchall())
        # TODO: bulk-load with `COPY document (...) FROM STDIN` (fed by
        # lines_to_stringio) for faster inserts. With the psycopg 3 driver
        # used here that is `cursor.copy(...)`; psycopg2's `copy_expert`
        # does not exist in psycopg 3.
finally:
    # Return the connection to the pool even if the query fails.
    conn.close()

# Write docs to DB

In [12]:
%%time
from tqdm.notebook import trange, tqdm

for i in trange(centers.shape[0]):
    
    embeddings = (centers[i] + error).tolist()  # batch_size x EMBEDDINGS_LEN
    # print(embeddings.shape)
    
    # idx = np.arange(i*BATCH_SIZE, (i+1)*BATCH_SIZE)
    
    idx = list(range(i*BATCH_SIZE, (i+1)*BATCH_SIZE ))
    
    params = (
        dict(
            id=idx[j],
            content=f'some unique content', #{idx[j]}',
            embedding=embeddings[j]
        )
        for j in range(BATCH_SIZE)
    )
    session.bulk_insert_mappings(Document, params)
    
    session.flush()
session.commit()
    
    # session.add_all((
    #     Document(id=i*BATCH_SIZE + j, content=f'some unique content #{i*BATCH_SIZE + j}', embedding=embeddings[j, :])
    #     for j in range(BATCH_SIZE)
    # ))

  0%|          | 0/200 [00:00<?, ?it/s]

CPU times: user 1min 55s, sys: 3.19 s, total: 1min 59s
Wall time: 2min 52s


In [14]:
# Display the active ORM session object.
session

<sqlalchemy.orm.session.Session at 0x10ac33250>

In [14]:
# Planner's row-count estimate from pg_class (cheap; refreshed by VACUUM,
# ANALYZE and some DDL such as CREATE INDEX).
session.execute(
    text("SELECT reltuples AS estimate FROM pg_class WHERE relname = 'document'")
).all()

[(200000.0,)]

In [15]:
# Exact row count (full scan), to compare with the planner estimate.
session.execute(text("SELECT count(*) from document")).all()

[(200000,)]

# Search 

## Baseline

In [16]:
# Baseline communication time  - retrieve by IDs
import random

x = random.randint(0, centers.shape[0])
print(x)

# session.scalars(select(Document).order_by(Document.embedding.cosine_distance(doc_X_embeddings)).limit(10))
def retrieve_by_ids():
    return session.query(Document).filter(Document.id.in_(range(x, x+10))).all()
%timeit retrieve_by_ids()
docs_found = retrieve_by_ids()
print([
    d.id
    for d in docs_found
])

104
3.12 ms ± 160 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
[104, 105, 106, 107, 108, 109, 110, 111, 112, 113]


## Search by cosine distance (without index)

In [17]:
# Ensure no vector index exists, so the next search measures a plain exact
# scan. Commit immediately: DROP INDEX holds a table lock until the
# transaction ends.
session.execute(text('DROP INDEX IF EXISTS my_index'))
session.commit()

In [18]:
import random
from tabulate import tabulate
# from utils.cosine_dist import cos_dist

x = random.randint(0, centers.shape[0])
doc_X_embeddings = centers[x]

def search(doc_X_embeddings):
    return session.scalars(select(Document).order_by(Document.embedding.cosine_distance(doc_X_embeddings)).limit(10))

%timeit search(doc_X_embeddings)

docs_found = search(doc_X_embeddings)
print(f'For center #{x}, the closest docs found:')
print(tabulate([
    {'ID': d.id, 'Dist.': cos_dist(d.embedding, doc_X_embeddings)}
    for d in docs_found
], headers='keys'))


604 ms ± 34.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
For center #188, the closest docs found:
    ID       Dist.
------  ----------
188755  0.00313165
188217  0.00321476
188315  0.00324013
188117  0.0032782
188339  0.00328146
188390  0.00330763
188717  0.00331707
188493  0.0033199
188577  0.00332645
188343  0.00335264


## Create index

In [25]:
%%time
from sqlalchemy import Index
session.execute(text('DROP INDEX IF EXISTS my_index'))
index = Index('my_index', Document.embedding,
    postgresql_using='ivfflat',
    postgresql_with={'lists': 200},
    postgresql_ops={'embedding': 'vector_cosine_ops'}
)
index.create(engine)

KeyboardInterrupt: 

## Search with index

In [24]:
%timeit search(doc_X_embeddings)
from tabulate import tabulate


docs_found = search(doc_X_embeddings)
print(f'For center #{x}, the closest docs found:')
print(tabulate([
    {'ID': d.id, 'Dist.': cos_dist(d.embedding, doc_X_embeddings)}
    for d in docs_found
], headers='keys'))

4.15 ms ± 69.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
For center #188, the closest docs found:
    ID       Dist.
------  ----------
188755  0.00313165
188217  0.00321476
188315  0.00324013
188117  0.0032782
188339  0.00328146
188717  0.00331707
188493  0.0033199
188577  0.00332645
188343  0.00335264
188558  0.00336192
