In [91]:
# Standard library
import getpass
import os
import time

# Data processing
from langchain.text_splitter import RecursiveCharacterTextSplitter
import pandas as pd

# DB connection
from google.cloud.sql.connector import Connector
import sqlalchemy
from sqlalchemy import create_engine, text

# Embeddings
from vertexai.language_models import TextEmbeddingModel

# Utils
from tqdm import tqdm
import numpy as np

In [3]:
class DataProcessor:
    """Split every row's description into chunks tagged with the row's product id.

    :param df: DataFrame providing ``product_id`` and ``description`` columns.
    :param text_splitter: object exposing ``create_documents(list_of_texts)``
        whose returned items carry a ``page_content`` attribute.
    """

    def __init__(self, df, text_splitter):
        self.df = df
        self.text_splitter = text_splitter
        # Records produced by the most recent process() call.
        self.chunked = []

    def process(self):
        """Build one ``{"product_id", "content"}`` record per chunk.

        The result is rebuilt from scratch on every call, so re-running the
        method (for example by re-executing a notebook cell) does not append
        duplicates to the records of a previous run.

        :return: the list of records, also stored on ``self.chunked``.
        """
        # Only two columns are needed; plain tuples avoid building a Series per row.
        pairs = self.df[["product_id", "description"]].itertuples(index=False, name=None)
        records = [
            {"product_id": product_id, "content": document.page_content}
            for product_id, description in pairs
            for document in self.text_splitter.create_documents([description])
        ]
        self.chunked = records
        return records

In [32]:
class EmbeddingGenerator:
    """Attach embeddings to chunk records, calling the model in fixed-size batches.

    :param model: object exposing ``get_embeddings(list_of_texts)``; each returned
        item must carry a ``values`` attribute.
    :param batch_size: number of records sent to the model per request.
    """

    def __init__(self, model, batch_size):
        self.model = model
        self.batch_size = batch_size

    def generate_embeddings(self, chunked):
        """Add an ``"embedding"`` entry to every record of ``chunked``, in place.

        :param chunked: list of dicts, each with a ``"content"`` key.
        :raises ValueError: if the model returns a different number of
            embeddings than records were sent, which would misalign the pairing.
        :raises Exception: whatever the model call raised on its last retry.
        """
        for start in tqdm(range(0, len(chunked), self.batch_size)):
            batch = chunked[start:start + self.batch_size]
            contents = [record["content"] for record in batch]
            response = self.retry_with_backoff(self.model.get_embeddings, contents)
            if len(response) != len(batch):
                raise ValueError(
                    f"expected {len(batch)} embeddings, got {len(response)}"
                )
            for record, embedding in zip(batch, response):
                record["embedding"] = embedding.values

    def retry_with_backoff(self, func, *args, retry_delay=5, backoff_factor=2, max_attempts=10):
        """Call ``func(*args)``, retrying with exponential backoff on any exception.

        :param func: callable to invoke.
        :param args: positional arguments forwarded to ``func``.
        :param retry_delay: seconds to wait before the first retry.
        :param backoff_factor: multiplier applied to the wait after each failure.
        :param max_attempts: total number of calls before giving up.
        :return: the value returned by the first successful call.
        :raises ValueError: if ``max_attempts`` is smaller than 1.
        :raises Exception: the error of the final attempt, re-raised so callers
            never receive a silent ``None`` after all attempts failed.
        """
        if max_attempts < 1:
            raise ValueError("max_attempts must be at least 1")
        for attempt in range(max_attempts):
            try:
                return func(*args)
            except Exception as e:
                print(f"error: {e}")
                if attempt == max_attempts - 1:
                    # Out of attempts: surface the failure instead of sleeping again.
                    raise
                # The first retry waits exactly retry_delay, then the wait grows geometrically.
                wait = retry_delay * (backoff_factor ** attempt)
                print(f"Retry after waiting for {wait} seconds...")
                time.sleep(wait)

In [129]:
class DatabaseInterface:
    """Run SQL against a database reached through ``Connector`` and a SQLAlchemy engine.

    :param instance_connection_name: first argument handed to ``Connector.connect``.
    :param db_user: database user name.
    :param db_pass: password of ``db_user``.
    :param db_name: database to connect to.
    """

    def __init__(self, instance_connection_name, db_user, db_pass, db_name):
        self.instance_connection_name = instance_connection_name
        self.db_user = db_user
        self.db_pass = db_pass
        self.db_name = db_name
        self.connector = Connector()
        self.pool = self.create_pool()

    def get_conn(self):
        """Open one new pg8000 connection through the connector."""
        return self.connector.connect(
            self.instance_connection_name,
            "pg8000",
            user=self.db_user,
            password=self.db_pass,
            db=self.db_name,
        )

    def create_pool(self):
        """Create the SQLAlchemy engine; it obtains connections from ``get_conn``."""
        return create_engine(
            "postgresql+pg8000://",
            creator=self.get_conn,
        )

    def run_query(self, query, fetch=True):
        """Execute ``query`` in its own transaction.

        Rows are read before the transaction is committed, so the result is
        never consumed after the transaction has ended. ``Engine.begin()``
        commits when the block exits normally and rolls back on an exception.

        :param query: SQL text to execute.
        :param fetch: when true, return the rows produced by the statement.
        :return: list of rows if ``fetch`` is true (empty list for statements
            that produce no rows, such as DDL), otherwise ``None``.
        """
        with self.pool.begin() as connection:
            result = connection.execute(text(query))
            if not fetch:
                return None
            # fetchall() raises on statements without a result set (e.g. DDL).
            return result.fetchall() if result.returns_rows else []

    def create_table(self, table_name, columns, if_not_exists=False):
        """
        Creates a table with the given name and columns.
        :param table_name: The name of the table
        :param columns: A dictionary where keys are column names and values are SQL data types
        :param if_not_exists: When true, emit ``CREATE TABLE IF NOT EXISTS`` so
            re-running the call is not an error
        """
        # NOTE: names and types are interpolated into the SQL text verbatim;
        # only pass trusted values.
        cols = ', '.join(f'{col} {dtype}' for col, dtype in columns.items())
        guard = 'IF NOT EXISTS ' if if_not_exists else ''
        create_table_query = f'CREATE TABLE {guard}{table_name} ({cols});'
        self.run_query(create_table_query, fetch=False)

    def drop_table(self, table_name, if_exists=False):
        """
        Drops the table with the given name.
        :param table_name: The name of the table
        :param if_exists: When true, emit ``DROP TABLE IF EXISTS`` so dropping
            a missing table is not an error
        """
        # NOTE: the name is interpolated into the SQL text verbatim; only pass
        # trusted values.
        guard = 'IF EXISTS ' if if_exists else ''
        drop_table_query = f'DROP TABLE {guard}{table_name};'
        # run_query returns None here, so there is nothing worth printing.
        self.run_query(drop_table_query, fetch=False)

### 1 - Raw Data Processing

In [9]:
DATASET_URL = "https://github.com/GoogleCloudPlatform/python-docs-samples/raw/main/cloud-sql/postgres/pgvector/data/retail_toy_dataset.csv"

# Load the CSV, keep four columns, and drop every row that has a missing
# value in any of them.
df = (
    pd.read_csv(DATASET_URL)
    .loc[:, ["product_id", "product_name", "description", "list_price"]]
    .dropna()
)

df.head(3)

Unnamed: 0,product_id,product_name,description,list_price
0,7e8697b5b7cdb5a40daf54caf1435cd5,"Koplow Games Set of 2 D12 12-Sided Rock, Paper...","Rock, paper, scissors is a great way to resolv...",3.56
1,7de8b315b3cb91f3680eb5b88a20dcee,"12""-20"" Schwinn Training Wheels",Turn any small bicycle into an instrument for ...,28.17
2,fb9535c103d7d717f0414b2b111cfaaa,Bicycle Pinochle Jumbo Index Playing Cards - 1...,Purchase includes 1 blue deck and 1 red deck. ...,6.49


In [14]:
# Splitter configuration: chunk size 500 as measured by len(), no overlap
# between chunks, "." and newline as the separators.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=0,
    length_function=len,
    separators=[".", "\n"],
)

# Chunk every description, then keep only the first 200 records.
data_processor = DataProcessor(df=df, text_splitter=text_splitter)
chunked_data = data_processor.process()[:200]

### 2 - Computation of Embeddings

In [33]:
# Number of chunk texts sent to the embedding model per request.
batch_size = 5
model = TextEmbeddingModel.from_pretrained("textembedding-gecko@001")

# Adds an "embedding" entry to each record of chunked_data in place.
embedding_model = EmbeddingGenerator(model=model, batch_size=batch_size)
embedding_model.generate_embeddings(chunked=chunked_data)

100%|██████████| 40/40 [00:12<00:00,  3.29it/s]


### 3 - Loading Data to DB

In [80]:
# Coordinates of the database instance.
project_id = "steam-378309"
region = "europe-west3"
instance_name = "legalm"

DB_NAME = "retail"
DB_USER = "postgres"
DB_PORT = "5432"

# The password must never be stored in the notebook: take it from the DB_PASS
# environment variable, or prompt for it without echoing when that is unset.
DB_PASS = os.environ.get("DB_PASS") or getpass.getpass(
    f"Password for database user {DB_USER!r}: "
)

INSTANCE_CONNECTION_NAME = f"{project_id}:{region}:{instance_name}"
print(f"Your instance connection name is: {INSTANCE_CONNECTION_NAME}")

Your instance connection name is: steam-378309:europe-west3:legalm


In [130]:
# Build the database helper from the connection settings defined above.
db_interface = DatabaseInterface(
    instance_connection_name=INSTANCE_CONNECTION_NAME,
    db_user=DB_USER,
    db_pass=DB_PASS,
    db_name=DB_NAME,
)

In [138]:
# Column name -> SQL type for the products table.
product_fields = dict(
    product_id="VARCHAR(1024)",
    product_name="TEXT",
    description="TEXT",
    list_price="NUMERIC",
)

# Column name -> SQL type for the per-chunk embeddings table.
# NOTE(review): VECTOR(768) presumably needs the pgvector extension and a
# 768-dimensional embedding model - confirm both.
product_embedding_fields = dict(
    product_id="VARCHAR(1024)",
    content="TEXT",
    embedding="VECTOR(768)",
)

In [132]:
# Start from a clean slate. IF EXISTS keeps this cell runnable on a database
# where the table has not been created yet (e.g. Restart & Run All).
db_interface.run_query("DROP TABLE IF EXISTS product_embeddings;", fetch=False)

None


In [115]:
# Selects the name of every table listed in pg_catalog.pg_tables whose schema
# is neither 'pg_catalog' nor 'information_schema'.
query = """
        SELECT tablename 
        FROM pg_catalog.pg_tables 
        WHERE schemaname != 'pg_catalog' 
        AND schemaname != 'information_schema';
        """

In [139]:
# Create the embeddings table from the column spec defined above.
db_interface.create_table(
    table_name="product_embeddings",
    columns=product_embedding_fields,
)

In [140]:
# Show the rows returned by the table-listing query.
db_interface.run_query(query=query)

[('products',), ('product_embeddings',)]

In [None]:
# NOTE(review): this cell re-declares DatabaseInterface, which an earlier cell
# of this notebook already defines; running it replaces that definition.
# Keep a single definition once the notebook is cleaned up.
class DatabaseInterface:
    """Run SQL against a database reached through ``Connector`` and a SQLAlchemy engine.

    :param instance_connection_name: first argument handed to ``Connector.connect``.
    :param db_user: database user name.
    :param db_pass: password of ``db_user``.
    :param db_name: database to connect to.
    """

    def __init__(self, instance_connection_name, db_user, db_pass, db_name):
        self.instance_connection_name = instance_connection_name
        self.db_user = db_user
        self.db_pass = db_pass
        self.db_name = db_name
        self.connector = Connector()
        self.pool = self.create_pool()

    def get_conn(self):
        """Open one new pg8000 connection through the connector."""
        return self.connector.connect(
            self.instance_connection_name,
            "pg8000",
            user=self.db_user,
            password=self.db_pass,
            db=self.db_name,
        )

    def create_pool(self):
        """Create the SQLAlchemy engine; it obtains connections from ``get_conn``."""
        return create_engine(
            "postgresql+pg8000://",
            creator=self.get_conn,
        )

    def run_query(self, query, fetch=True):
        """Execute ``query`` in its own transaction and commit it.

        ``Engine.begin()`` commits when the block exits normally and rolls
        back on an exception, so DDL and writes are persisted instead of being
        discarded when the connection is released.

        :param query: SQL text to execute.
        :param fetch: when true, return the rows produced by the statement.
        :return: list of rows if ``fetch`` is true (empty list for statements
            that produce no rows, such as DDL), otherwise ``None``.
        """
        with self.pool.begin() as connection:
            result = connection.execute(text(query))
            if not fetch:
                return None
            # fetchall() raises on statements without a result set (e.g. DDL).
            return result.fetchall() if result.returns_rows else []

    def create_table(self, table_name, columns, if_not_exists=False):
        """
        Creates a table with the given name and columns.
        :param table_name: The name of the table
        :param columns: A dictionary where keys are column names and values are SQL data types
        :param if_not_exists: When true, emit ``CREATE TABLE IF NOT EXISTS`` so
            re-running the call is not an error
        """
        # NOTE: names and types are interpolated into the SQL text verbatim;
        # only pass trusted values.
        cols = ', '.join(f'{col} {dtype}' for col, dtype in columns.items())
        guard = 'IF NOT EXISTS ' if if_not_exists else ''
        create_table_query = f'CREATE TABLE {guard}{table_name} ({cols});'
        self.run_query(create_table_query, fetch=False)

    def drop_table(self, table_name, if_exists=False):
        """
        Drops the table with the given name.
        :param table_name: The name of the table
        :param if_exists: When true, emit ``DROP TABLE IF EXISTS`` so dropping
            a missing table is not an error
        """
        # NOTE: the name is interpolated into the SQL text verbatim; only pass
        # trusted values.
        guard = 'IF EXISTS ' if if_exists else ''
        drop_table_query = f'DROP TABLE {guard}{table_name};'
        self.run_query(drop_table_query, fetch=False)

In [None]:
# NOTE(review): this cell is an unexecuted outline, not working code.
# DatabaseConnector and DatabaseOperations are not defined in the visible part
# of this notebook, and asyncio is not among the visible imports, so running
# the cell as written raises NameError.
async def main():
    """Outline of the end-to-end flow: embed the chunks, then create, fill and query a table.

    Reads the notebook globals ``embedding_model`` and ``chunked_data``; the
    lines that would create them locally are commented out below.
    """
    # Initialize the objects
    # text_splitter = RecursiveCharacterTextSplitter(...) 
    # data_processor = DataProcessor(df, text_splitter)
    # embedding_model = EmbeddingGenerator(...)
    # The literal ... (Ellipsis) is a placeholder for the real arguments.
    db_connector = DatabaseConnector(...)
    
    # Process the data
    # chunked_data = data_processor.process()
    
    # Generate embeddings
    embedding_model.generate_embeddings(chunked_data)
    
    # Connect to the database
    conn = await db_connector.connect()
    
    # Database Operations
    db_operations = DatabaseOperations(conn)
    await db_operations.create_table()
    await db_operations.insert_data(chunked_data)
    # The literal ... (Ellipsis) is a placeholder for the real query arguments.
    results = await db_operations.query_data(...)
    
    # NOTE(review): not reached if any step above raises - consider try/finally.
    await conn.close()

# NOTE(review): asyncio.run() cannot be called while an event loop is already
# running in the same thread; Jupyter kernels typically run one, in which case
# `await main()` is the form to use - confirm for the target environment.
asyncio.run(main())
