In [13]:
# DB connection
from google.cloud.sql.connector import Connector
# from pgvector.asyncpg import register_vector
import sqlalchemy
# import asyncio
# import asyncpg

# Data processing
from langchain.text_splitter import RecursiveCharacterTextSplitter
import pandas as pd
import os

# Embeddings
from langchain.embeddings import VertexAIEmbeddings
from vertexai.language_models import TextEmbeddingModel
import time

# Utils
from tqdm import tqdm
import numpy as np

In [5]:
# --- Cloud SQL instance coordinates -------------------------------------
# The project id comes from the environment; region and instance are fixed.
region = "europe-west3"
instance_name = "legalm"
project_id = os.getenv("PROJECT_ID")

# --- Database credentials / settings -------------------------------------
# The password is read from the environment so it never appears in the notebook.
DB_NAME = "pubmed"
DB_USER = "postgres"
DB_PORT = "5432"
DB_PASS = os.getenv("DB_PASS")

# Connection name in the "<project>:<region>:<instance>" form handed to the
# Cloud SQL connector by the connection helpers in this notebook.
INSTANCE_CONNECTION_NAME = f"{project_id}:{region}:{instance_name}"
print(f"Your instance connection name is: {INSTANCE_CONNECTION_NAME}")

Your instance connection name is: steam-378309:europe-west3:legalm


In [6]:
from google.cloud.sql.connector import Connector

# Shared Cloud SQL connector instance; getconn() below calls its connect()
# method each time a new database connection is needed.
connector = Connector()

def getconn():
    """Open and return a new pg8000 connection to the Cloud SQL instance.

    Uses the module-level ``connector`` together with
    ``INSTANCE_CONNECTION_NAME`` and the ``DB_USER`` / ``DB_PASS`` /
    ``DB_NAME`` settings. The function takes no arguments so it can be
    handed to SQLAlchemy as the ``creator`` callable of the engine.
    """
    credentials = {"user": DB_USER, "password": DB_PASS, "db": DB_NAME}
    return connector.connect(INSTANCE_CONNECTION_NAME, "pg8000", **credentials)

# SQLAlchemy engine whose DBAPI connections are all produced by getconn();
# the URL carries no host or credentials, it only names the
# postgresql+pg8000 dialect/driver pair.
engine_url = "postgresql+pg8000://"
pool = sqlalchemy.create_engine(engine_url, creator=getconn)

In [7]:
# Rebuilds the engine with the same URL and creator callable as the engine
# created next to getconn(); rebinding `pool` here replaces that earlier
# engine object with an equivalent new one.
pool = sqlalchemy.create_engine("postgresql+pg8000://", creator=getconn)

In [36]:
# Size of every database on the server, in MB rounded to two decimals.
# The byte count is cast to numeric BEFORE dividing: pg_database_size()
# returns bigint, and bigint / integer is integer division, so casting the
# already-truncated quotient made ROUND(..., 2) a no-op.
DB_SIZE_QUERY = """SELECT pg_database.datname AS database_name
                   ,ROUND(pg_database_size(pg_database.datname)::numeric / 1048576, 2) AS data_size_MB
                   FROM pg_database;"""

# Row count of the products table.
QUERY = """SELECT COUNT(*) FROM products;"""

In [37]:
# Execute QUERY on a connection borrowed from the engine and collect every
# row while the connection is still open.
with pool.connect() as db_conn:
    results = db_conn.execute(sqlalchemy.text(QUERY)).fetchall()

# The rows are already materialised, so they can be printed afterwards.
for row in results:
    print(row)

(848,)


In [8]:
# Public CSV with the retail toy catalogue used throughout the notebook.
DATASET_URL = "https://github.com/GoogleCloudPlatform/python-docs-samples/raw/main/cloud-sql/postgres/pgvector/data/retail_toy_dataset.csv"

# Download the CSV, keep only the four columns used below, and drop every
# row that has a missing value in any of them.
df = (
    pd.read_csv(DATASET_URL)
    .loc[:, ["product_id", "product_name", "description", "list_price"]]
    .dropna()
)

df.head(10)

Unnamed: 0,product_id,product_name,description,list_price
0,7e8697b5b7cdb5a40daf54caf1435cd5,"Koplow Games Set of 2 D12 12-Sided Rock, Paper...","Rock, paper, scissors is a great way to resolv...",3.56
1,7de8b315b3cb91f3680eb5b88a20dcee,"12""-20"" Schwinn Training Wheels",Turn any small bicycle into an instrument for ...,28.17
2,fb9535c103d7d717f0414b2b111cfaaa,Bicycle Pinochle Jumbo Index Playing Cards - 1...,Purchase includes 1 blue deck and 1 red deck. ...,6.49
3,c73ea622b3be6a3ffa3b0b5490e4929e,Step2 Woodland Adventure Playhouse & Slide,The Step2 Woodland Climber Adventure Playhouse...,499.99
4,dec7bd1f983887650715c6fafaa5b593,Step2 Naturally Playful Welcome Home Playhouse...,Children can play and explore in the Step2 Nat...,600.0
5,74a695e3675efc2aad11ed73c46db29b,Slip N Slide Triple Racer with Slide Boogies,Triple Racer Slip and Slide with Boogie Boards...,37.21
6,3eae5293b56c25f63b47cb8a89fb4813,Hydro Tools Digital Pool/Spa Thermometer,The solar-powered Swimline Floating Digital Th...,15.92
7,ed85bf829a36c67042503ffd9b6ab475,Full Bucket Swing With Coated Chain Toddler Sw...,Safe Kids&Children Full Bucket Swing With Coa...,102.26
8,55820fa53f0583cb637d5cb2b051d78c,Banzai Water Park Splash Zone,Dive into fun in your own backyard with the B...,397.82
9,0e26a9e92e4036bfaa68eb2040a8ec97,Polaris 39-310 5-Liter Zippered Super Bag for ...,Keep your pool water sparkling clean all seaso...,39.47


In [9]:
# Splitter that breaks a description on "." and then "\n", with a target
# chunk size of 500 (measured with len) and no overlap between chunks.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=0,
    separators=[".", "\n"],
    length_function=len,
)

# One record per chunk, tagged with the product the chunk came from.
chunked = [
    {"product_id": product_id, "content": document.page_content}
    for product_id, description in zip(df["product_id"], df["description"])
    for document in text_splitter.create_documents([description])
]

In [11]:
# Total number of chunks produced from all product descriptions.
len(chunked)

2669

In [15]:
# Number of text chunks sent to the embedding model per request.
batch_size = 5
# Embedding model handle; used for the chunk embeddings and, later in the
# notebook, for embedding the search phrase.
model = TextEmbeddingModel.from_pretrained("textembedding-gecko@001")

def retry_with_backoff(func, *args, retry_delay=5, backoff_factor=2, max_attempts=10, **kwargs):
    """Call ``func(*args, **kwargs)``, retrying with exponential backoff on failure.

    Args:
        func: Callable to invoke.
        *args: Positional arguments forwarded to ``func``.
        retry_delay: Base delay in seconds.
        backoff_factor: Growth factor of the delay; the wait after the k-th
            failed attempt is ``retry_delay * backoff_factor ** k``.
        max_attempts: Total number of calls to make before giving up.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The value returned by the first successful call to ``func``.

    Raises:
        ValueError: If ``max_attempts`` is smaller than 1.
        Exception: Whatever the final attempt raised, once every attempt
            has failed.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")

    for attempt in range(1, max_attempts + 1):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(f"error: {e}")
            if attempt == max_attempts:
                # Out of attempts: surface the failure to the caller instead
                # of sleeping once more and silently returning None.
                raise
            wait = retry_delay * (backoff_factor**attempt)
            print(f"Retry after waiting for {wait} seconds...")
            time.sleep(wait)

# Only the leading part of `chunked` is embedded: batch start offsets stop
# 2660 entries before the end of the list.
embed_limit = len(chunked) - 2660

for start in tqdm(range(0, embed_limit, batch_size)):
    batch = chunked[start : start + batch_size]
    # One request per batch, retried with backoff if the call fails.
    embeddings = retry_with_backoff(
        model.get_embeddings, [item["content"] for item in batch]
    )
    # Attach each returned vector to the chunk record it belongs to.
    for item, embedding in zip(batch, embeddings):
        item["embedding"] = embedding.values

100%|██████████| 2/2 [00:00<00:00,  2.13it/s]


In [16]:
# Build the frame from the first 20 chunks, keeping only those that actually
# received an "embedding" key. The embedding loop covers just part of
# `chunked`, and a chunk without an embedding would otherwise show up as NaN
# in the "embedding" column.
product_embeddings = pd.DataFrame(
    [chunk for chunk in chunked[:20] if "embedding" in chunk],
    columns=["product_id", "content", "embedding"],
)
product_embeddings.head()

Unnamed: 0,product_id,content,embedding
0,7e8697b5b7cdb5a40daf54caf1435cd5,"Rock, paper, scissors is a great way to resolv...","[-0.014531989581882954, -0.01446803379803896, ..."
1,7e8697b5b7cdb5a40daf54caf1435cd5,". Great for educational games, dice games, boa...","[-0.010937819257378578, -0.05220745503902435, ..."
2,7de8b315b3cb91f3680eb5b88a20dcee,Turn any small bicycle into an instrument for ...,"[-0.02734817937016487, -0.02363274060189724, 0..."
3,7de8b315b3cb91f3680eb5b88a20dcee,. Durable Construction: Steel brackets stand u...,"[-0.00025529583217576146, -0.02829601615667343..."
4,7de8b315b3cb91f3680eb5b88a20dcee,. Tools required: Adjustable wrench. www.schwi...,"[-0.012775714509189129, -0.02170153334736824, ..."


In [76]:
async def main():
    """Recreate the ``product_embeddings`` table and load the embeddings into it.

    Connects to the Cloud SQL instance with asyncpg, makes sure the
    ``vector`` extension exists, drops and recreates ``product_embeddings``
    (destructive: existing rows are lost), then inserts one row per row of
    the module-level ``product_embeddings`` DataFrame.

    Returns:
        None.
    """
    # Imported locally: the notebook's top-level imports of asyncio and
    # pgvector's register_vector are commented out, so on a fresh kernel
    # these names would otherwise be undefined.
    import asyncio
    from pgvector.asyncpg import register_vector

    loop = asyncio.get_running_loop()
    async with Connector(loop=loop) as connector:
        conn = await connector.connect_async(
            INSTANCE_CONNECTION_NAME,
            "asyncpg",
            user=DB_USER,
            password=DB_PASS,
            db=DB_NAME,
        )

        # try/finally so the connection is closed even if a statement fails.
        try:
            await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
            # Registers the vector type on this connection so array values
            # can be passed as query arguments.
            await register_vector(conn)

            await conn.execute("DROP TABLE IF EXISTS product_embeddings")

            await conn.execute(
                """CREATE TABLE product_embeddings(
                                product_id VARCHAR(1024) NOT NULL REFERENCES products(product_id),
                                content TEXT,
                                embedding vector(768))"""
            )

            # total= lets tqdm show a real progress bar instead of a bare counter.
            for index, row in tqdm(product_embeddings.iterrows(), total=len(product_embeddings)):
                await conn.execute(
                    "INSERT INTO product_embeddings (product_id, content, embedding) VALUES ($1, $2, $3)",
                    row["product_id"],
                    row["content"],
                    np.array(row["embedding"]),
                )
        finally:
            await conn.close()

In [21]:
# Embedding of the first row (column position 2 of product_embeddings),
# converted to a NumPy array for inspection.
comp = np.array(product_embeddings.iloc[0,2])

In [23]:
# Dimensionality check: the table column is declared as vector(768).
comp.shape

(768,)

In [77]:
# Runs main() as it is defined at this point in the notebook (the coroutine
# that rebuilds and fills the product_embeddings table).
await main()

200it [00:03, 53.52it/s]


In [87]:
# Free-text search phrase and the list_price bounds applied to the matches.
toy = "small bicycle"
min_price = 25
max_price = 100 

# Embed the search phrase with the same `model` used for the stored chunks.
embedding_query = model.get_embeddings([toy])[0]

In [89]:
# Peek at the first five components of the query embedding.
embedding_query.values[:5]

[-0.022654814645648003,
 -0.02820597030222416,
 0.03192044794559479,
 0.029981357976794243,
 -0.013982265256345272]

In [103]:
matches = []

async def main():
    """Search ``product_embeddings`` for chunks similar to ``embedding_query``.

    Runs a pgvector cosine-distance query (``<=>``) against the stored chunk
    embeddings, joins the best chunks to ``products`` and keeps the products
    whose ``list_price`` lies between the module-level ``min_price`` and
    ``max_price``. The rows are stored as dicts in the module-level
    ``matches`` list, which is reset on every call.

    Returns:
        None; results are delivered through ``matches``.

    Raises:
        Exception: If the query returns no rows.
    """
    # Imported locally: the notebook's top-level imports of asyncio and
    # pgvector's register_vector are commented out, so on a fresh kernel
    # these names would otherwise be undefined.
    import asyncio
    from pgvector.asyncpg import register_vector

    # Start from a fresh list on every call. This keeps re-runs from
    # accumulating duplicates and also works after another cell has rebound
    # `matches` to a DataFrame.
    global matches
    matches = []

    loop = asyncio.get_running_loop()
    async with Connector(loop=loop) as connector:
        conn = await connector.connect_async(
            INSTANCE_CONNECTION_NAME,
            "asyncpg",
            user=DB_USER,
            password=DB_PASS,
            db=DB_NAME,
        )

        # try/finally so the connection is closed on every path, including
        # a failing query.
        try:
            await register_vector(conn)

            # Minimum similarity (1 - cosine distance) and maximum number of
            # chunk matches considered before joining to products.
            similarity_threshold = 0.1
            num_matches = 50

            results = await conn.fetch(
                """
                            WITH vector_matches AS (
                              SELECT product_id, 1 - (embedding <=> $1) AS similarity
                              FROM product_embeddings
                              WHERE 1 - (embedding <=> $1) > $2
                              ORDER BY similarity DESC
                              LIMIT $3
                            )
                            SELECT p.product_name, p.list_price, p.description, v.similarity FROM products p
                            JOIN vector_matches v ON p.product_id = v.product_id 
                            AND list_price >= $4 AND list_price <= $5
                            ORDER BY v.similarity DESC
                            """,
                embedding_query.values,
                similarity_threshold,
                num_matches,
                min_price,
                max_price,
            )
        finally:
            await conn.close()

    # Checked after the connection is released so the error path no longer
    # skips conn.close().
    if len(results) == 0:
        raise Exception("Did not find any results. Adjust the query parameters.")

    matches.extend(
        {
            "product_name": r["product_name"],
            "description": r["description"],
            "similarity": r["similarity"],
            "list_price": round(r["list_price"], 2),
        }
        for r in results
    )

In [104]:
# Run the similarity search; main() fills the module-level `matches` list.
await main()

# NOTE: this rebinds `matches` from a list of dicts to a DataFrame.
matches = pd.DataFrame(matches)
matches.head(5)

Unnamed: 0,product_name,description,similarity,list_price
0,"12""-20"" Schwinn Training Wheels",Turn any small bicycle into an instrument for ...,0.639393,28.17
1,"12""-20"" Schwinn Training Wheels",Turn any small bicycle into an instrument for ...,0.637463,28.17
2,"12""-20"" Schwinn Training Wheels",Turn any small bicycle into an instrument for ...,0.622518,28.17
3,Beach Sandy Sand Remover Brush Pack for clean ...,SAND BE GONE! Effortlessly clean sand from fee...,0.621293,25.0
4,"12""-20"" Schwinn Training Wheels",Turn any small bicycle into an instrument for ...,0.608088,28.17


In [None]:
class DatabaseInterface:
    """Small wrapper around a SQLAlchemy engine backed by the Cloud SQL connector.

    The engine never sees host or credentials directly; every DBAPI
    connection is produced by ``get_conn`` through the connector.
    """

    def __init__(self, instance_connection_name, db_user, db_pass, db_name):
        """Store the connection settings and build the connector and engine.

        Args:
            instance_connection_name: Cloud SQL instance connection name
                passed to the connector.
            db_user: Database user name.
            db_pass: Database password.
            db_name: Name of the database to connect to.
        """
        self.instance_connection_name = instance_connection_name
        self.db_user = db_user
        self.db_pass = db_pass
        self.db_name = db_name
        self.connector = Connector()
        self.pool = self.create_pool()

    def get_conn(self):
        """Open and return a new pg8000 connection through the connector."""
        return self.connector.connect(
            self.instance_connection_name,
            "pg8000",
            user=self.db_user,
            password=self.db_pass,
            db=self.db_name,
        )

    def create_pool(self):
        """Create the SQLAlchemy engine that draws connections from ``get_conn``."""
        return sqlalchemy.create_engine(
            "postgresql+pg8000://",
            creator=self.get_conn,
        )

    def run_query(self, query, params=None):
        """Execute a SQL string and return all resulting rows.

        Args:
            query: SQL text; may contain ``:name`` bind placeholders.
            params: Optional mapping of bind-parameter names to values.
                Prefer this over formatting values into ``query``.

        Returns:
            The list of rows produced by ``fetchall()``.
        """
        # `text` was referenced unqualified before, but only the `sqlalchemy`
        # module is imported, so every call raised NameError.
        statement = sqlalchemy.text(query)
        with self.pool.connect() as connection:
            if params is None:
                result = connection.execute(statement)
            else:
                result = connection.execute(statement, params)
            return result.fetchall()

    def close(self):
        """Release pooled connections and shut down the connector."""
        self.pool.dispose()
        self.connector.close()