In [1]:
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_community.document_loaders import PyPDFLoader
import nltk
from langchain_text_splitters import NLTKTextSplitter
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.messages import SystemMessage
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from IPython.display import Markdown as md
from dotenv import load_dotenv
import os

In [2]:
load_dotenv()  
key = os.getenv("GOOGLE_API_KEY")

In [4]:
chat_model = ChatGoogleGenerativeAI(google_api_key=key, 
                                   model="gemini-1.5-flash-latest")
loader = PyPDFLoader("data/4.SQL.pdf")
pages = loader.load_and_split()
pages

[Document(metadata={'producer': '適用於 Microsoft 365 的 Microsoft® PowerPoint®', 'creator': '適用於 Microsoft 365 的 Microsoft® PowerPoint®', 'creationdate': '2024-12-13T17:06:41+08:00', 'title': 'Chapter 4:  SQL', 'author': 'Marilyn Turnamian', 'moddate': '2024-12-13T17:06:41+08:00', 'source': 'data/4.SQL.pdf', 'total_pages': 60, 'page': 0, 'page_label': '1'}, page_content='3.1\nSQL'),
 Document(metadata={'producer': '適用於 Microsoft 365 的 Microsoft® PowerPoint®', 'creator': '適用於 Microsoft 365 的 Microsoft® PowerPoint®', 'creationdate': '2024-12-13T17:06:41+08:00', 'title': 'Chapter 4:  SQL', 'author': 'Marilyn Turnamian', 'moddate': '2024-12-13T17:06:41+08:00', 'source': 'data/4.SQL.pdf', 'total_pages': 60, 'page': 1, 'page_label': '2'}, page_content='3.2\nSQL\nBasic Query Structure\nSet Operations\nAggregate Functions\nNull Values\nNested Subqueries\nComplex Queries \nViews\nModification of the Database\nJoined Relations** \nData Definition Language'),
 Document(metadata={'producer': '適用於 Mic

In [5]:
text_splitter = NLTKTextSplitter(chunk_size=500, chunk_overlap=100)

chunks = text_splitter.split_documents(pages)
print(len(chunks))
print(type(chunks[0]))

83
<class 'langchain_core.documents.base.Document'>


In [6]:
embedding_model = GoogleGenerativeAIEmbeddings(google_api_key=key, model="models/embedding-001")
db = Chroma.from_documents(chunks, embedding_model, persist_directory="chroma_db_")
db.persist()
db_connection = Chroma(persist_directory="chroma_db_", embedding_function=embedding_model)

  db.persist()
  db_connection = Chroma(persist_directory="chroma_db_", embedding_function=embedding_model)


In [7]:
retriever = db_connection.as_retriever(search_kwargs={"k": 5})

print(type(retriever))

<class 'langchain_core.vectorstores.base.VectorStoreRetriever'>


In [8]:
chat_template = ChatPromptTemplate.from_messages([
    SystemMessage(content="""You are a teacher in Scaffolding Instruction education.
                  Given a context and question from user,
                  you should answer based on the given context."""),
    HumanMessagePromptTemplate.from_template("""Answer the question based on the given context.
    Context: {context}
    Question: {question}
    Answer: """)
])

output_parser = StrOutputParser()


def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)


rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | chat_template
    | chat_model
    | output_parser
)

In [9]:
response = rag_chain.invoke("""how to Find the names of all branches that have greater
assets than all branches located in Brooklyn in SQL""")
md(response)

```sql
select branch_name
from branch
where assets > all (select assets from branch where branch_city = 'Brooklyn');
```