In [1]:
# Embedding creation
import torch
from transformers import AutoTokenizer, AutoModel

# DB connection
from google.cloud.sql.connector import Connector
import sqlalchemy
from sqlalchemy import create_engine, text

# Utils
from tqdm import tqdm

# Data processing
from langchain.text_splitter import RecursiveCharacterTextSplitter
import pandas as pd
import numpy as np
import os

In [2]:
class DatabaseInterface:
    def __init__(self, instance_connection_name, db_user, db_pass, db_name):
        self.instance_connection_name = instance_connection_name
        self.db_user = db_user
        self.db_pass = db_pass
        self.db_name = db_name
        self.connector = Connector()
        self.pool = self.create_pool()

    def get_conn(self):
        conn = self.connector.connect(
            self.instance_connection_name,
            "pg8000",
            user=self.db_user,
            password=self.db_pass,
            db=self.db_name
        )
        return conn

    def create_pool(self):
        return create_engine(
            "postgresql+pg8000://",
            creator=self.get_conn,
        )
    
    def run_query(self, query, fetch=True):
        with self.pool.connect() as connection:
            try:
                result = connection.execute(text(query))
                connection.commit() 
                if fetch:
                    return result.fetchall()
                else:
                    return None
            except Exception as e:
                    print("EXCEPTION THROWN")
                    print(e)
                    connection.rollback()  
            
    def insert_data_from_dataframe(self, dataframe, table_name):
            try:
                dataframe.to_sql(
                    table_name,
                    self.pool,
                    if_exists='append',
                    index=False,
                    method='multi'
                )
            except Exception as e:
                print("EXCEPTION THROWN DURING INSERT")
                print(e)
        
    def create_table(self, table_name, columns):
        """
        Creates a table with the given name and columns.
        :param table_name: The name of the table
        :param columns: A dictionary where keys are column names and values are SQL data types
        """
        cols = ', '.join(f'{col} {dtype}' for col, dtype in columns.items())
        create_table_query = f'CREATE TABLE {table_name} ({cols});'
        self.run_query(create_table_query, fetch=False)

    def drop_table(self, table_name):
        """
        Drops the table with the given name.
        :param table_name: The name of the table
        """
        drop_table_query = f'DROP TABLE {table_name};'
        print(self.run_query(drop_table_query, fetch=False))

In [1]:
project_id = os.environ.get("PROJECT_ID")
region = "europe-west3"
instance_name = "legalm"

DB_NAME = "pubmed"
DB_USER = "postgres"
DB_PASS = os.environ.get("DB_PASS")
DB_PORT = "5432"

INSTANCE_CONNECTION_NAME = f"{project_id}:{region}:{instance_name}"
print(f"Your instance connection name is: {INSTANCE_CONNECTION_NAME}")

NameError: name 'os' is not defined

In [4]:
db_interface = DatabaseInterface(INSTANCE_CONNECTION_NAME, DB_USER, DB_PASS, DB_NAME)

In [5]:
table_query = """
        SELECT tablename 
        FROM pg_catalog.pg_tables 
        WHERE schemaname != 'pg_catalog' 
        AND schemaname != 'information_schema';
        """

In [6]:
db_interface.run_query(table_query)

[('pm_abstracts',), ('pm_abstracts_embeddings',)]

In [10]:
abstracts_query = """
                  SELECT * FROM pm_abstracts;
                  """

In [11]:
all_abstracts = db_interface.run_query(abstracts_query)

In [12]:
text_splitter = RecursiveCharacterTextSplitter(
    separators=[".", "\n"],
    chunk_size=500,
    chunk_overlap=0,
    length_function=len,
)

In [13]:
chunked = []

# Iterate over each tuple in the list
for abstract_tuple in tqdm(all_abstracts):
    pmid, title, abstract, _ = abstract_tuple  # Unpack the tuple
    
    splits = text_splitter.create_documents([abstract])
    for s in splits:
        r = {"pmid": pmid, "title": title, "abstract": s.page_content}
        chunked.append(r)

100%|██████████| 17076/17076 [00:01<00:00, 12645.14it/s]


In [15]:
model_name = "bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

In [16]:
# Extract all abstracts
all_abstracts = [item["abstract"] for item in chunked]

# Tokenize all abstracts at once
all_tokenized_inputs = tokenizer(all_abstracts, return_tensors="pt", padding=True, truncation=True)

In [19]:
# Initialize an empty list to store the data
data = []
batch_size = 64
decrease_by = 69000

# Loop through the tokenized inputs in batches
for i in tqdm(range(0, len(chunked) - decrease_by, batch_size)):
    batch_abstracts_dicts = chunked[i:i+batch_size]
    
    # Extract the relevant slices from the pre-tokenized inputs
    batch_input_ids = all_tokenized_inputs['input_ids'][i:i+batch_size]
    batch_attention_mask = all_tokenized_inputs['attention_mask'][i:i+batch_size]
    
    # Generate embeddings for the encoded text
    with torch.no_grad():
        outputs = model(input_ids=batch_input_ids, attention_mask=batch_attention_mask)
        batch_embeddings = outputs.last_hidden_state[:, 0, :].numpy()

    # Append the data to the list
    for j, embedding in enumerate(batch_embeddings):
        pmid = batch_abstracts_dicts[j]["pmid"]
        title = batch_abstracts_dicts[j]["title"]
        abstract = batch_abstracts_dicts[j]["abstract"]
        data.append([pmid, title, abstract, embedding])

100%|██████████| 30/30 [10:29<00:00, 21.00s/it]


In [20]:
df = pd.DataFrame(data, columns=['pmid', 'title', 'abstract', 'embedding'])
# df['abstract'] = df['abstract'].apply(lambda x: np.array(x))

In [21]:
df.shape

(1920, 4)

In [22]:
db_interface.run_query("select * from pm_abstracts_embeddings")

[]

In [23]:
df.dtypes

pmid          int64
title        object
abstract     object
embedding    object
dtype: object

In [24]:
# df.to_csv('4096_pubmed.csv', index=False)

In [25]:
db_interface.run_query("TRUNCATE pm_abstracts_embeddings", fetch=False)

In [26]:
orig = df.iloc[0,3]

In [27]:
orig.shape

(768,)

In [28]:
type(orig)

numpy.ndarray

In [126]:
insert_query = f"""
               INSERT INTO pm_abstracts_embeddings (pmid, title, abstract, embedding) 
               VALUES (123, 'Tit', 'abs', {orig})
               """

In [128]:
db_interface.run_query("CREATE EXTENSION IF NOT EXISTS vector")

EXCEPTION THROWN
This result object does not return rows. It has been closed automatically.


In [131]:
DATASET_URL = "https://github.com/GoogleCloudPlatform/python-docs-samples/raw/main/cloud-sql/postgres/pgvector/data/retail_toy_dataset.csv"

df2 = pd.read_csv(DATASET_URL)
df2 = df2.loc[:, ["product_id", "product_name", "description", "list_price"]]
df2 = df2.dropna()

df2.head(10)

Unnamed: 0,product_id,product_name,description,list_price
0,7e8697b5b7cdb5a40daf54caf1435cd5,"Koplow Games Set of 2 D12 12-Sided Rock, Paper...","Rock, paper, scissors is a great way to resolv...",3.56
1,7de8b315b3cb91f3680eb5b88a20dcee,"12""-20"" Schwinn Training Wheels",Turn any small bicycle into an instrument for ...,28.17
2,fb9535c103d7d717f0414b2b111cfaaa,Bicycle Pinochle Jumbo Index Playing Cards - 1...,Purchase includes 1 blue deck and 1 red deck. ...,6.49
3,c73ea622b3be6a3ffa3b0b5490e4929e,Step2 Woodland Adventure Playhouse & Slide,The Step2 Woodland Climber Adventure Playhouse...,499.99
4,dec7bd1f983887650715c6fafaa5b593,Step2 Naturally Playful Welcome Home Playhouse...,Children can play and explore in the Step2 Nat...,600.0
5,74a695e3675efc2aad11ed73c46db29b,Slip N Slide Triple Racer with Slide Boogies,Triple Racer Slip and Slide with Boogie Boards...,37.21
6,3eae5293b56c25f63b47cb8a89fb4813,Hydro Tools Digital Pool/Spa Thermometer,The solar-powered Swimline Floating Digital Th...,15.92
7,ed85bf829a36c67042503ffd9b6ab475,Full Bucket Swing With Coated Chain Toddler Sw...,Safe Kids&Children Full Bucket Swing With Coa...,102.26
8,55820fa53f0583cb637d5cb2b051d78c,Banzai Water Park Splash Zone,Dive into fun in your own backyard with the B...,397.82
9,0e26a9e92e4036bfaa68eb2040a8ec97,Polaris 39-310 5-Liter Zippered Super Bag for ...,Keep your pool water sparkling clean all seaso...,39.47


In [133]:
text_splitter = RecursiveCharacterTextSplitter(
    separators=[".", "\n"],
    chunk_size=500,
    chunk_overlap=0,
    length_function=len,
)

chunked = []

for index, row in df2.iterrows():
    product_id = row["product_id"]
    desc = row["description"]
    splits = text_splitter.create_documents([desc])
    for s in splits:
        r = {"product_id": product_id, "content": s.page_content}
        chunked.append(r)

In [134]:
len(chunked)

2669

In [136]:
from vertexai.language_models import TextEmbeddingModel

In [137]:
batch_size = 5
model = TextEmbeddingModel.from_pretrained("textembedding-gecko@001")

def retry_with_backoff(func, *args, retry_delay=5, backoff_factor=2, **kwargs):
    max_attempts = 10
    retries = 0
    for i in range(max_attempts):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(f"error: {e}")
            retries += 1
            wait = retry_delay * (backoff_factor**retries)
            print(f"Retry after waiting for {wait} seconds...")
            time.sleep(wait)

for i in tqdm(range(0, len(chunked)-2660, batch_size)):
    request_data = [x["content"] for x in chunked[i : i + batch_size]]
    response = retry_with_backoff(model.get_embeddings, request_data)
    for x, e in zip(chunked[i : i + batch_size], response):
        x["embedding"] = e.values

100%|██████████| 2/2 [00:00<00:00,  2.06it/s]


In [179]:
d = pd.read_csv("4096_pubmed.csv")

In [158]:
row = chunked[0]
embedding = np.array(row["embedding"])
product_id = row["product_id"]

e = embedding.tolist()

In [159]:
ni_query = f"INSERT INTO products (product_id, embedding) VALUES ('{product_id}', ARRAY{e})"

In [45]:
db_interface.run_query("TRUNCATE pm_abstracts_embeddings", fetch=False)

In [46]:
for i in tqdm(range(0, len(df))):
    r = df.iloc[i]
    
    # Define the SQL template with placeholders
    insert_query = """
        INSERT INTO pm_abstracts_embeddings (pmid, title, abstract, embedding) 
        VALUES (:pmid, :title, :abstract, :embedding)"""

    # Create a dictionary of the values to insert
    values_to_insert = {
        "pmid": r["pmid"],
        "title": r["title"],
        "abstract": r["abstract"],
        "embedding": str(r["embedding"].tolist())
    }

    # Execute the parameterized query
    try:
        with db_interface.pool.connect() as connection:
            connection.execute(text(insert_query), values_to_insert)
            connection.commit()
    except Exception as e:
        print("EXCEPTION THROWN")
        print(e)
        connection.rollback()

100%|██████████| 1920/1920 [02:36<00:00, 12.27it/s]


In [1]:
db_interface.run_query("SELECT * FROM pm_abstracts_embeddings LIMIT 10")

NameError: name 'db_interface' is not defined

In [75]:
matches = []

similarity_threshold = 0.001
num_matches = 50

user_query_embedding = str(df.iloc[0]["embedding"].tolist())

sim_query_values = {
                    "user_query_embedding": user_query_embedding,
                    "similarity_threshold": similarity_threshold,
                    "num_matches": num_matches
                    }

sim_query = """SELECT pmid, 1 - (embedding <=> :user_query_embedding) AS similarity
               FROM pm_abstracts_embeddings
               WHERE 1 - (embedding <=> :user_query_embedding) > :similarity_threshold
               ORDER BY similarity DESC
               LIMIT :num_matches
            """

In [76]:
results = []

try:
    with db_interface.pool.connect() as connection:
        cursor = connection.execute(text(sim_query), sim_query_values)
        results = cursor.fetchall()
        connection.commit()
except Exception as e:
    print("EXCEPTION THROWN")
    print(e)
    connection.rollback()
        
if len(results) == 0:
    raise Exception("Did not find any results. Adjust the query parameters.")
    
print(results)

[(36434554, 1.0), (36434608, 0.9044531030761114), (36434749, 0.8958483012400144), (36434719, 0.8928046178408402), (36434749, 0.8859170543228674), (36434830, 0.8841930280460857), (36434770, 0.8833488838780184), (36434616, 0.8831285908356069), (36434938, 0.8811583147877649), (36434839, 0.8805230177790722), (36434667, 0.878112647238501), (36435391, 0.8756767229825894), (36434591, 0.8752964718876801), (36434955, 0.8746033887535118), (36434729, 0.8741800456871791), (36434667, 0.8712848973639666), (36434835, 0.8709204070308273), (36434706, 0.870667826012837), (36434938, 0.8695244731661221), (36434943, 0.8685940810484042), (36434990, 0.8679936489331247), (36434674, 0.8660701011103776), (36434679, 0.8642882054951763), (36435057, 0.8639257081413138), (36434667, 0.8601994747306938), (36434689, 0.8591346423946651), (36435088, 0.8590170164448878), (36434591, 0.857740304148317), (36434602, 0.8564888007690586), (36434574, 0.8552545837595658), (36434976, 0.8552462878318501), (36434741, 0.854226953161