In [30]:
%pip install InstructorEmbedding ipywidgets langchain psycopg2-binary python-dotenv scikit-learn sentence_transformers tqdm  

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Collecting python-dotenv
  Using cached python_dotenv-1.0.0-py3-none-any.whl (19 kB)
Installing collected packages: python-dotenv
Successfully installed python-dotenv-1.0.0
Note: you may need to restart the kernel to use updated packages.


In [31]:
import os
from dotenv import load_dotenv

load_dotenv()

True

In [46]:
import psycopg2

try:
    # todo move to env vars
    connection = psycopg2.connect(
        os.environ["DATABASE_URL"]
    )
    print("Connection established successfully!")
except psycopg2.Error as e:
    print(f"Error connecting to PostgreSQL: {e}")

Connection established successfully!


Prepare the database by creating the required table

In [5]:
with connection.cursor() as cursor:
    cursor.execute("DROP TABLE IF EXISTS dataset;")
    cursor.execute("CREATE TABLE dataset(id SERIAL PRIMARY KEY, text TEXT, embeddings DOUBLE PRECISION[]);")
    connection.commit()

We need a similarity function. I'm using cosine similarity here. I'm not sure if that's the best choice, but it's a good starting point. As a reference point, I'll be creating my own similarity function, we can meausre against built in functions in pgvector.

In [6]:
with connection.cursor() as cursor:
    cursor.execute("DROP FUNCTION IF EXISTS cosine_distance(a DOUBLE PRECISION[], b DOUBLE PRECISION[]);")
    cursor.execute("""
        CREATE FUNCTION cosine_distance(a DOUBLE PRECISION[], b DOUBLE PRECISION[]) RETURNS DOUBLE PRECISION AS $$
        DECLARE
            dot DOUBLE PRECISION := 0;
            mag_a DOUBLE PRECISION := 0;
            mag_b DOUBLE PRECISION := 0;
            i INTEGER := 1;
        BEGIN
            WHILE i <= array_length(a, 1) LOOP
                dot := dot + a[i] * b[i];
                mag_a := mag_a + a[i] * a[i];
                mag_b := mag_b + b[i] * b[i];
                i := i + 1;
            END LOOP;
            RETURN 1 - (dot / sqrt(mag_a * mag_b));
        END;
        $$ LANGUAGE plpgsql;

    """)
    connection.commit()

We need data to embed. I'm using a small corpus of the top 5 books from the Gutenberg project. The data is in the data/ directory.

In [7]:
# read the corpus
with open("data/gutenberg-top-5.txt", "r") as f:
    corpus = f.read()
    # split into paragraphs
    corpus = corpus.split("\n\n")
    # remove blank lines
    corpus = [line for line in corpus if line.strip() != ""]
    print(f"Corpus length: {len(corpus)}")

Corpus length: 26776


Next we need to create the embeddings for the data. I'm using the `HuggingFaceInstructEmbeddings` model here not because it's the best, but because I can run it locally. Sine I'm on  

In [8]:
from langchain.embeddings import HuggingFaceInstructEmbeddings
import torch

if torch.backends.mps.is_available():
    device = torch.device("mps:0")
else:
    print("MPS device not found.")
    device = torch.device("cpu")


hf = HuggingFaceInstructEmbeddings(
    model_name="hkunlp/instructor-large",
    model_kwargs={'device': device},
    encode_kwargs={'normalize_embeddings': True}
)

  from tqdm.autonotebook import trange


load INSTRUCTOR_Transformer
max_seq_length  512


In [9]:
from tqdm.notebook import tqdm
with connection.cursor() as cursor:
    for line in tqdm(corpus[:1000]):
        embedding = hf.embed_documents(line)[0]
        cursor.execute("INSERT INTO dataset(text, embeddings) VALUES (%s, %s)", (line, embedding))
        connection.commit()

  0%|          | 0/1000 [00:00<?, ?it/s]

  assert torch.sum(attention_mask[local_idx]).item() >= context_masks[local_idx].item(),\


What we want to figure out here is how fast we can retrieve similarly rated content for given text.
* What sort of indexes do we need to create to make this fast?
* How fast can we make it?
* Does partitioning the data help?
* How much does the pgvector extension help?
* Does clustering the embeddings help

I'm using embeddings from a random query here as an example. In an ideal usecase you should be creating your own embeddings from your query

In [10]:

test_query = f"""
    WITH target_item AS (
        SELECT * 
        FROM dataset 
        WHERE id=60
    )
    SELECT ds.text, cosine_distance(ds.embeddings, target_item.embeddings) as distance
    FROM dataset ds, target_item 
    ORDER BY distance DESC 
    LIMIT 20;
"""

In [11]:
with connection.cursor() as cursor:
    cursor.execute(test_query)
    for line in cursor.fetchall():
        text = line[0].strip().replace("\n", " ")
        score = line[1]
        print(f"{score}: {text}")

0.24268183173368: CHORUS. Now old desire doth in his deathbed lie, And young affection gapes to be his heir; That fair for which love groan’d for and would die, With tender Juliet match’d, is now not fair. Now Romeo is belov’d, and loves again, Alike bewitched by the charm of looks; But to his foe suppos’d he must complain, And she steal love’s sweet bait from fearful hooks: Being held a foe, he may not have access To breathe such vows as lovers use to swear; And she as much in love, her means much less To meet her new beloved anywhere. But passion lends them power, time means, to meet, Tempering extremities with extreme sweet.
0.24268183173368: CAPULET. So many guests invite as here are writ.
0.24268183173368: CAPULET. Why how now, kinsman! Wherefore storm you so?
0.24268183173368: CAPULET. Will you tell me that? His son was but a ward two years ago.
0.24268183173368: CAPULET. Nay, gentlemen, prepare not to be gone, We have a trifling foolish banquet towards. Is it e’en so? Why then, 

The resulting distances aren't the best looking however there is some resemblance in the text. Let's analyze the query plan to see what's going on.

In [12]:
with connection.cursor() as cursor:
    cursor.execute("EXPLAIN ANALYZE " + test_query)
    for line in cursor.fetchall():
        print(line[0])

Limit  (cost=319.73..319.78 rows=20 width=135) (actual time=581.897..581.901 rows=20 loops=1)
  ->  Sort  (cost=319.73..322.21 rows=992 width=135) (actual time=581.896..581.897 rows=20 loops=1)
        Sort Key: (cosine_distance(ds.embeddings, dataset.embeddings)) DESC
        Sort Method: top-N heapsort  Memory: 36kB
        ->  Nested Loop  (cost=0.28..293.33 rows=992 width=135) (actual time=0.644..581.407 rows=1000 loops=1)
              ->  Index Scan using dataset_pkey on dataset  (cost=0.28..2.49 rows=1 width=18) (actual time=0.012..0.013 rows=1 loops=1)
                    Index Cond: (id = 60)
              ->  Seq Scan on dataset ds  (cost=0.00..32.92 rows=992 width=145) (actual time=0.010..0.270 rows=1000 loops=1)
Planning Time: 0.114 ms
Execution Time: 581.946 ms


Let's add an index to the embeddings column and see if that helps.

In [13]:
with connection.cursor() as cursor:
    cursor.execute("DROP TABLE IF EXISTS indexed_dataset;")
    cursor.execute("DROP INDEX IF EXISTS indexed_dataset_embeddings_idx;")

    cursor.execute("CREATE TABLE indexed_dataset AS SELECT * from dataset;")
    cursor.execute("CREATE INDEX indexed_dataset_embeddings_idx ON indexed_dataset USING GIN(embeddings);")
    
    connection.commit()

In [17]:
test_indexed_query = f"""
  WITH target_item AS (
      SELECT * 
      FROM indexed_dataset 
      WHERE id=60
  )
  SELECT ds.text, cosine_distance(ds.embeddings, target_item.embeddings) as distance
  FROM indexed_dataset ds, target_item 
  ORDER BY distance DESC 
  LIMIT 20;
"""

In [20]:
with connection.cursor() as cursor:
    cursor.execute("EXPLAIN ANALYZE " + test_indexed_query)
    for line in cursor.fetchall():
        print(line[0])

Limit  (cost=355.11..355.16 rows=20 width=134) (actual time=591.074..591.078 rows=20 loops=1)
  ->  Sort  (cost=355.11..357.61 rows=1000 width=134) (actual time=591.073..591.074 rows=20 loops=1)
        Sort Key: (cosine_distance(ds.embeddings, indexed_dataset.embeddings)) DESC
        Sort Method: top-N heapsort  Memory: 36kB
        ->  Nested Loop  (cost=0.00..328.50 rows=1000 width=134) (actual time=1.624..590.544 rows=1000 loops=1)
              ->  Seq Scan on indexed_dataset  (cost=0.00..35.50 rows=1 width=18) (actual time=0.060..0.184 rows=1 loops=1)
                    Filter: (id = 60)
                    Rows Removed by Filter: 999
              ->  Seq Scan on indexed_dataset ds  (cost=0.00..33.00 rows=1000 width=144) (actual time=0.002..0.336 rows=1000 loops=1)
Planning Time: 0.860 ms
Execution Time: 591.280 ms


Looks like having an index doesn't improve the query all that much because it needs to scan the entire table to to compute the distances.
One way to mitigate this is to cluster the table on the embeddings column. This will group similar embeddings together on disk and make it faster to scan the table. We can't use the `CLUSTER` feature because it's not supported on GIN indexes, however we can pre-compute the clusters using a K-Means clustering algorithm.

In [21]:
with connection.cursor() as cursor:
    cursor.execute("SELECT embeddings FROM dataset;")
    embeddings = cursor.fetchall()
    embeddings = [e[0] for e in embeddings]

In [23]:
from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=5, random_state=0, n_init="auto").fit(embeddings)
cluster_data = {
    'clusters': kmeans.labels_.tolist(),
    'centroids': kmeans.cluster_centers_.tolist(),
    'iterations': kmeans.n_iter_,
    'inertia': kmeans.inertia_,
    'converged': kmeans.n_iter_ < kmeans.max_iter,
    'counts': {
        '0': (kmeans.labels_ == 0).sum(),
        '1': (kmeans.labels_ == 1).sum(),
        '2': (kmeans.labels_ == 2).sum(),
        '3': (kmeans.labels_ == 3).sum(),
        '4': (kmeans.labels_ == 4).sum(),
    }
}

print(cluster_data['counts'])

{'0': 564, '1': 94, '2': 159, '3': 116, '4': 67}


In [24]:
with connection.cursor() as cursor:
    cursor.execute("DROP TABLE IF EXISTS clusters;")
    cursor.execute("CREATE TABLE clusters(id INTEGER, centroid double precision[]);")
    for i, centroid in enumerate(cluster_data['centroids']):
        cursor.execute("INSERT INTO clusters(id, centroid) VALUES (%s, %s)", (i, centroid))
    connection.commit()

Now that we have a cluster map we can add a cluster index to the table.

In [25]:
with connection.cursor() as cursor:
    cursor.execute("DROP TABLE IF EXISTS clustered_dataset;")
    cursor.execute("CREATE TABLE clustered_dataset AS SELECT * from dataset;")
    cursor.execute("ALTER TABLE clustered_dataset ADD COLUMN cluster INTEGER;")

    for i, cluster in enumerate(cluster_data['clusters']):
        cursor.execute("UPDATE clustered_dataset SET cluster=%s WHERE id=%s", (cluster, i+1))

    connection.commit()

Now querying data happens in 2 steps
1. Find the cluster that the query text belongs to
2. Find the nearest neighbors within that cluster

In [38]:
clustered_test_query = """
    EXPLAIN ANALYZE WITH target_item AS (
    SELECT id, embeddings, cluster
    FROM clustered_dataset
    WHERE id = 60
    ),
    nearest_centroid AS (
    SELECT c.id AS centroid_id, c.centroid
    FROM clusters c
    ORDER BY cosine_distance(c.centroid, (SELECT embeddings FROM target_item)) DESC
    LIMIT 1
    )
    SELECT ds.id, ds.text
    FROM clustered_dataset ds
    WHERE ds.cluster = (SELECT centroid_id FROM nearest_centroid)
    AND ds.id <> 60
    ORDER BY cosine_distance(ds.embeddings, (SELECT embeddings FROM target_item)) DESC
    LIMIT 20;
"""

In [39]:
with connection.cursor() as cursor:
    cursor.execute(clustered_test_query)
    for line in cursor.fetchall():
        print(line[0])

Limit  (cost=519.43..519.48 rows=20 width=138) (actual time=96.114..96.122 rows=20 loops=1)
  CTE target_item
    ->  Seq Scan on clustered_dataset  (cost=0.00..57.50 rows=1 width=26) (actual time=0.364..0.510 rows=1 loops=1)
          Filter: (id = 60)
          Rows Removed by Filter: 999
  CTE nearest_centroid
    ->  Limit  (cost=346.57..346.57 rows=1 width=44) (actual time=4.475..4.477 rows=1 loops=1)
          InitPlan 2 (returns $1)
            ->  CTE Scan on target_item  (cost=0.00..0.02 rows=1 width=32) (actual time=0.366..0.512 rows=1 loops=1)
          ->  Sort  (cost=346.55..349.73 rows=1270 width=44) (actual time=4.474..4.475 rows=1 loops=1)
                Sort Key: (cosine_distance(c.centroid, $1)) DESC
                Sort Method: top-N heapsort  Memory: 25kB
                ->  Seq Scan on clusters c  (cost=0.00..340.20 rows=1270 width=44) (actual time=2.169..4.463 rows=5 loops=1)
  InitPlan 4 (returns $3)
    ->  CTE Scan on target_item target_item_1  (cost=0.00..0.0

This has greatly improved the performance. We've been able to eliminate the need to scan the entire table (it's only looking at 563 rows). The performance here will be impacted by the number of clusters you are able to generate. however this does mean that you may be excluding results that are somewhat close but in neibooring clusters. One approach to mitigate this is to query the nearest neighbors from the neighboring clusters as well.

In [40]:
with connection.cursor() as cursor:
    cursor.execute("DROP INDEX IF EXISTS clustered_dataset_pkey;")
    cursor.execute("DROP INDEX IF EXISTS idx_clustered_dataset_cluster;")
    cursor.execute("DROP INDEX IF EXISTS idx_clusters_centroid;")
    cursor.execute("CREATE UNIQUE INDEX clustered_dataset_pkey ON clustered_dataset(id int4_ops);")
    cursor.execute("CREATE INDEX idx_clustered_dataset_cluster ON clustered_dataset (cluster);")
    cursor.execute("CREATE INDEX idx_clusters_centroid ON clusters USING GIN (centroid);")
    connection.commit()


In [41]:
with connection.cursor() as cursor:
    cursor.execute(clustered_test_query)
    for line in cursor.fetchall():
        print(line[0])

Limit  (cost=104.43..104.48 rows=20 width=138) (actual time=97.251..97.258 rows=20 loops=1)
  CTE target_item
    ->  Index Scan using clustered_dataset_pkey on clustered_dataset  (cost=0.28..2.49 rows=1 width=26) (actual time=0.031..0.032 rows=1 loops=1)
          Index Cond: (id = 60)
  CTE nearest_centroid
    ->  Limit  (cost=2.34..2.35 rows=1 width=44) (actual time=3.091..3.093 rows=1 loops=1)
          InitPlan 2 (returns $1)
            ->  CTE Scan on target_item  (cost=0.00..0.02 rows=1 width=32) (actual time=0.033..0.034 rows=1 loops=1)
          ->  Sort  (cost=2.32..2.34 rows=5 width=44) (actual time=3.090..3.091 rows=1 loops=1)
                Sort Key: (cosine_distance(c.centroid, $1)) DESC
                Sort Method: top-N heapsort  Memory: 25kB
                ->  Seq Scan on clusters c  (cost=0.00..2.30 rows=5 width=44) (actual time=0.684..3.082 rows=5 loops=1)
  InitPlan 4 (returns $3)
    ->  CTE Scan on target_item target_item_1  (cost=0.00..0.02 rows=1 width=32) (

Adding indexes allows us to replace some sequential scans with index scans. This isn't that big of an improvement here but would be more significant with a larger dataset.

We can also use the pgvector extension to create a GIN index on the embeddings column. This will allow us to use the `pgvector <-> pgvector` operator to find the nearest neighbors.

In [47]:
with connection.cursor() as cursor:
    cursor.execute("DROP TABLE IF EXISTS vectorized_dataset;")
    cursor.execute("CREATE EXTENSION IF NOT EXISTS vector WITH SCHEMA extensions;");
    cursor.execute("CREATE TABLE vectorized_dataset AS SELECT * from dataset;")
    
    # 768 seems to be the default length for the instructor-large model
    cursor.execute("ALTER TABLE vectorized_dataset ADD COLUMN embeddings_vector vector(768);")
    connection.commit()


Let's copy over our data to the new table and see how it performs.

In [48]:
with connection.cursor() as cursor:
    cursor.execute("SELECT embeddings FROM vectorized_dataset;")
    embeddings = cursor.fetchall()
    embeddings = [e[0] for e in embeddings]
    for i, embedding in enumerate(embeddings):
        cursor.execute("UPDATE vectorized_dataset SET embeddings_vector=%s WHERE id=%s", (embedding, i+1))
    connection.commit()
    

In [49]:
vectorized_test_query = """
    WITH target_item AS (
        SELECT * 
        FROM vectorized_dataset 
        WHERE id=60
    )
    SELECT ds.text, ds.embeddings_vector <=> target_item.embeddings_vector as score
    FROM vectorized_dataset ds, target_item
    ORDER BY score DESC
    LIMIT 20;
"""

In [50]:
with connection.cursor() as cursor:
    cursor.execute(vectorized_test_query)
    for line in cursor.fetchall():
        text = line[0].strip().replace("\n", " ")
        score = line[1]
        print(f"{score}: {text}")

0.242681830496018: CHORUS. Now old desire doth in his deathbed lie, And young affection gapes to be his heir; That fair for which love groan’d for and would die, With tender Juliet match’d, is now not fair. Now Romeo is belov’d, and loves again, Alike bewitched by the charm of looks; But to his foe suppos’d he must complain, And she steal love’s sweet bait from fearful hooks: Being held a foe, he may not have access To breathe such vows as lovers use to swear; And she as much in love, her means much less To meet her new beloved anywhere. But passion lends them power, time means, to meet, Tempering extremities with extreme sweet.
0.242681830496018: ROMEO. Yet banished? Hang up philosophy. Unless philosophy can make a Juliet, Displant a town, reverse a Prince’s doom, It helps not, it prevails not, talk no more.
0.242681830496018: TYBALT. This by his voice, should be a Montague. Fetch me my rapier, boy. What, dares the slave Come hither, cover’d with an antic face, To fleer and scorn at o

In [51]:
with connection.cursor() as cursor:
    cursor.execute("EXPLAIN ANALYZE" + vectorized_test_query)
    for line in cursor.fetchall():
        print(line[0])

Limit  (cost=155.61..155.66 rows=20 width=134) (actual time=7.986..7.991 rows=20 loops=1)
  ->  Sort  (cost=155.61..158.11 rows=1000 width=134) (actual time=7.985..7.987 rows=20 loops=1)
        Sort Key: ((ds.embeddings_vector <=> vectorized_dataset.embeddings_vector)) DESC
        Sort Method: top-N heapsort  Memory: 35kB
        ->  Nested Loop  (cost=0.00..129.00 rows=1000 width=134) (actual time=0.071..7.667 rows=1000 loops=1)
              ->  Seq Scan on vectorized_dataset  (cost=0.00..59.50 rows=1 width=18) (actual time=0.044..0.144 rows=1 loops=1)
                    Filter: (id = 60)
                    Rows Removed by Filter: 999
              ->  Seq Scan on vectorized_dataset ds  (cost=0.00..57.00 rows=1000 width=144) (actual time=0.002..0.119 rows=1000 loops=1)
Planning Time: 0.096 ms
Execution Time: 8.032 ms


The performance difference is staggering. We've gone from 333.352 ms to 8.112 ms. That's a 41x improvement.

In [52]:
with connection.cursor() as cursor:
    cursor.execute("DROP TABLE IF EXISTS indexed_vectorized_dataset;")
    cursor.execute("DROP INDEX IF EXISTS idx_indexed_vectorized_dataset_pkey;")
    cursor.execute(
        "DROP INDEX IF EXISTS idx_indexed_vectorized_dataset_embeddings_vector_idx;"
    )

    cursor.execute(
        "CREATE TABLE indexed_vectorized_dataset AS SELECT * from vectorized_dataset;")
    cursor.execute(
        "CREATE INDEX idx_indexed_vectorized_dataset_pkey ON indexed_vectorized_dataset(id int4_ops);")
    cursor.execute(
        "CREATE INDEX idx_indexed_vectorized_dataset_embeddings_vector ON indexed_vectorized_dataset USING ivfflat (embeddings_vector vector_cosine_ops) WITH (lists = 100);"
    )

    connection.commit()

In [53]:
indexed_vectorized_test_query = """
    WITH target_item AS (
        SELECT * 
        FROM indexed_vectorized_dataset 
        WHERE id=60
    )
    SELECT ds.text, ds.embeddings_vector <=> target_item.embeddings_vector as score
    FROM indexed_vectorized_dataset ds, target_item
    ORDER BY score DESC
    LIMIT 20;
"""

In [60]:
with connection.cursor() as cursor:
    cursor.execute("SET enable_seqscan=false;")
    connection.commit()
    cursor.execute("EXPLAIN ANALYZE" + indexed_vectorized_test_query)
    for line in cursor.fetchall():
        print(line[0])
    cursor.execute("SET enable_seqscan=true;")
    connection.commit()

Limit  (cost=10000000076.60..10000000076.65 rows=20 width=134) (actual time=8.033..8.036 rows=20 loops=1)
  ->  Sort  (cost=10000000076.60..10000000079.10 rows=1000 width=134) (actual time=8.031..8.033 rows=20 loops=1)
        Sort Key: ((ds.embeddings_vector <=> indexed_vectorized_dataset.embeddings_vector)) DESC
        Sort Method: top-N heapsort  Memory: 35kB
        ->  Nested Loop  (cost=10000000000.27..10000000049.99 rows=1000 width=134) (actual time=0.045..7.734 rows=1000 loops=1)
              ->  Index Scan using idx_indexed_vectorized_dataset_pkey on indexed_vectorized_dataset  (cost=0.28..2.49 rows=1 width=18) (actual time=0.012..0.013 rows=1 loops=1)
                    Index Cond: (id = 60)
              ->  Seq Scan on indexed_vectorized_dataset ds  (cost=10000000000.00..10000000035.00 rows=1000 width=144) (actual time=0.008..0.192 rows=1000 loops=1)
Planning Time: 0.117 ms
Execution Time: 8.080 ms
