In [9]:
# Dependencies for this cell, grouped up front.
import vertexai
from langchain_google_vertexai import VertexAIEmbeddings

# Google Cloud project / region passed to vertexai.init(), and the name of
# the embedding model handed to VertexAIEmbeddings.
PROJECT_ID = "imrenagi-gemini-experiment"
LOCATION = "us-central1"
GEMINI_EMBEDDING_MODEL = "text-embedding-004"

# Initialise the Vertex AI SDK before constructing any client that uses it.
vertexai.init(project=PROJECT_ID, location=LOCATION)

# Embedding client shared by later cells.
embeddings_service = VertexAIEmbeddings(model_name=GEMINI_EMBEDDING_MODEL)

In [10]:
%%writefile lib/pg_retriever.py

from typing import List

from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

from langchain_google_vertexai import VertexAIEmbeddings

import psycopg2
from pgvector.psycopg2 import register_vector

class CourseContentRetriever(BaseRetriever):
    """Retriever that finds course content relevant to a query.

    The query is embedded with ``embeddings_service`` and compared against the
    ``embedding`` column of ``course_content_embeddings`` using pgvector's
    ``<=>`` operator. Matching rows are joined to ``course_contents`` on ``id``
    to pick up the title, and returned as ``Document`` objects ordered from
    most to least similar.
    """

    # Client used to turn the query text into an embedding vector.
    embeddings_service: VertexAIEmbeddings
    # Only rows whose similarity (1 - distance) is strictly greater than this
    # value are returned.
    similarity_threshold: float
    # Upper bound (SQL LIMIT) on the number of embedding rows matched.
    num_matches: int
    # Connection string passed verbatim to ``psycopg2.connect``.
    conn_str: str

    def _get_relevant_documents(
            self, query: str, *, run_manager: CallbackManagerForRetrieverRun
        ) -> List[Document]:
        """Return the documents most similar to ``query``.

        Args:
            query: Free-text query to search for.
            run_manager: Callback manager supplied by the retriever framework;
                not used by this implementation.

        Returns:
            A list of ``Document`` objects, most similar first. Each has the
            matched content as ``page_content`` and ``id``, ``title`` and
            ``similarity`` in ``metadata``. Empty when nothing passes the
            similarity threshold.
        """
        # Embed before connecting so a slow or failing embedding call never
        # holds (or leaks) a database connection.
        query_embedding = self.embeddings_service.embed_query(query)

        conn = psycopg2.connect(self.conn_str)
        try:
            register_vector(conn)
            with conn.cursor() as cur:
                # The CTE picks the top-N embedding rows above the threshold.
                # The outer query is driven from those matches (inner join),
                # so only matching course rows come back, and it re-applies
                # the ordering because a join does not preserve CTE order.
                cur.execute(
                    """
                    WITH vector_matches AS (
                        SELECT id, content,
                               1 - (embedding <=> %s::vector) AS similarity
                        FROM course_content_embeddings
                        WHERE 1 - (embedding <=> %s::vector) > %s
                        ORDER BY similarity DESC
                        LIMIT %s
                    )
                    SELECT cc.id AS id, cc.title AS title,
                           vm.content AS content,
                           vm.similarity AS similarity
                    FROM vector_matches vm
                    JOIN course_contents cc ON cc.id = vm.id
                    ORDER BY vm.similarity DESC;
                    """,
                    (
                        query_embedding,
                        query_embedding,
                        self.similarity_threshold,
                        self.num_matches,
                    ),
                )
                rows = cur.fetchall()
        finally:
            # Always release the connection, even if the query raised.
            conn.close()

        # Skip rows with NULL content: Document requires a string page_content.
        return [
            Document(
                page_content=content,
                metadata={
                    "id": content_id,
                    "title": title,
                    "similarity": similarity,
                },
            )
            for content_id, title, content, similarity in rows
            if content is not None
        ]

Overwriting lib/pg_retriever.py


In [11]:
import os

from lib.pg_retriever import CourseContentRetriever

# Connection string for the course database. Set PG_CONN_STRING to point at a
# different database or to keep real credentials out of the notebook; the
# fallback is the local demo database.
db_conn_string = os.environ.get(
    "PG_CONN_STRING",
    "postgres://pyconapac:pyconapac@localhost:5432/pyconapac",
)

retriever = CourseContentRetriever(
    embeddings_service=embeddings_service,
    conn_str=db_conn_string,
    similarity_threshold=0.1,
    num_matches=10,
)

# invoke() builds its own callback/run manager, so none is passed here.
retriever.invoke("what is strategy for creating forgot password")

[Document(metadata={'id': 2, 'title': 'Forgot Password Cheat Sheet', 'similarity': 0.626802716961831}, page_content="1. Generate a token to the user and attach it in the URL query string.\n2. Send this token to the user via email.\n   - Don't rely on the [Host](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Host) header while creating the reset URLs to avoid [Host Header Injection](https://owasp.org/www-project-web-security-testing-guide/stable/4-Web_Application_Security_Testing/07-Input_Validation_Testing/17-Testing_for_Host_Header_Injection) attacks. The URL should be either be hard-coded, or should be validated against a list of trusted domains.\n   - Ensure that the URL is using HTTPS.\n3. The user receives the email, and browses to the URL with the attached token.\n   - Ensure that the reset password page adds the [Referrer Policy](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Referrer-Policy) tag with the `noreferrer` value in order to avoid [referrer leaka