# Import

In [1]:
import os
import requests
from typing import List

import chromadb
from chromadb.api.types import Documents, Embeddings
from chromadb.utils.embedding_functions import EmbeddingFunction

import google.generativeai as genai

import gradio as gr
import fitz  # 要安裝PyMuPDF

# Download PDF and Extract text from PDF

In [2]:
def download_pdf(url, save_path):
    """
    Download a PDF from a URL and save it to a local file.

    :param url: URL of the PDF document (string)
    :param save_path: Local path the PDF is written to (string)
    :raises requests.HTTPError: If the server answers with a 4xx/5xx status
    :raises requests.RequestException: On connection errors or timeouts
    """
    # A timeout keeps a stalled server from hanging the notebook forever.
    response = requests.get(url, timeout=30)

    # Fail loudly on HTTP errors instead of saving an error page as a ".pdf".
    response.raise_for_status()

    # Binary mode: the payload is written exactly as received.
    with open(save_path, 'wb') as f:
        f.write(response.content)


def extract_text_from_pdf_file_obj(file):
    """
    Return the text of every page of a PDF given as a file object.

    :param file: Object whose ``name`` attribute is the path of the PDF
                 (e.g. the result of open(path, 'rb'))
    :return: Text of all pages concatenated in page order, or a string
             starting with "Error while reading PDF:" if anything fails
    """
    try:
        with fitz.open(file.name) as document:
            # Concatenate the pages without any separator.
            return "".join(page.get_text() for page in document)
    except Exception as err:
        # Failures are reported through the return value, never raised.
        return f"Error while reading PDF: {err!s}"


def extract_text_from_pdf_file_path(file_path):
    """
    Return the text of every page of the PDF stored at a path.

    :param file_path: Path of the PDF file (string)
    :return: Text of all pages concatenated in page order, or a string
             starting with "Error while reading PDF:" if anything fails
    """
    try:
        with fitz.open(file_path) as document:
            pieces = [page.get_text() for page in document]
            return "".join(pieces)
    except Exception as err:
        # Failures are reported through the return value, never raised.
        return "Error while reading PDF: " + str(err)

# ToDo:
- Text splitting
- ChromaDB
- Prompt Construction

## Implement text splitting function

In [3]:
# Split text into small chunks
def split_text(text: str, max_chunk_size: int = 500, overlap: int = 20) -> List[str]:
    """
    Split a long text into chunks made of whole lines.

    Lines are stripped and packed into the current chunk while they fit;
    a line that does not fit starts a new chunk. Lines are never cut, so a
    single line longer than max_chunk_size becomes a chunk of its own that
    exceeds the limit. Empty chunks are never returned.

    :param text: Text to split (string)
    :param max_chunk_size: Target maximum size of a chunk in characters (int)
    :param overlap: Accepted for backward compatibility but not applied:
                    chunks are cut on line boundaries and share no text (int)
    :return: List of non-empty chunks (List of strings)
    """
    chunks: List[str] = []
    current_chunk = ""

    # Splitting on newlines keeps each menu entry (one per line) intact.
    for raw_line in text.split("\n"):
        line = raw_line.strip()

        # Measure the stripped line, because that is what gets stored.
        if len(current_chunk) + len(line) > max_chunk_size:
            finished = current_chunk.strip()
            # Nothing to flush when the very first line is already too long.
            if finished:
                chunks.append(finished)
            current_chunk = ""

        current_chunk += line + "\n"

    # Flush the trailing chunk unless it holds only whitespace.
    last_chunk = current_chunk.strip()
    if last_chunk:
        chunks.append(last_chunk)

    return chunks

## Custom embedding function using Gemini API

In [4]:
# Custom Gemini embedding function
class GeminiEmbeddingFunction(EmbeddingFunction):
    """Embedding function that embeds each document through genai.embed_content."""

    def __init__(self, api_key: str, model: str = "models/embedding-001", title: str = "Restaurant Menu"):
        """
        Store the settings and configure the genai client with the API key.

        :param api_key: Gemini API key handed to genai.configure
        :param model: Name of the embedding model
        :param title: Title sent with every embedding request; it was
                      originally "Custom query"
        """
        self.api_key = api_key
        self.model = model
        self.title = title
        genai.configure(api_key=self.api_key)

    def __call__(self, input: Documents) -> Embeddings:
        """
        Embed the documents one request at a time.

        :param input: Documents to embed
        :return: One embedding per document, in input order
        """
        vectors = []
        for document in input:
            # Every text is embedded with the "retrieval_document" task type.
            response = genai.embed_content(
                model=self.model,
                content=document,
                task_type="retrieval_document",
                title=self.title,
            )
            vectors.append(response["embedding"])
        return vectors

## Implement ChromaDB creation and querying

In [5]:
# Add new documents to a ChromaDB collection.
def update_chroma_db(client, collection_name: str, new_documents: List[str]):
    """
    Add documents to a ChromaDB collection, creating it if it does not exist.

    :param client: ChromaDB client that owns the collection
    :param collection_name: Name of the collection to update (string)
    :param new_documents: Documents to add (List of strings)
    """
    # Imported locally so this function does not rely on the import cell.
    import uuid

    # Fetch the collection, creating it on first use
    collection = client.get_or_create_collection(collection_name)

    # Add the documents one by one
    for document in new_documents:
        collection.add(
            # A random ID per document: index-based IDs ("new_doc_0", ...)
            # restart at 0 on every call and collide with earlier uploads.
            ids=[f"new_doc_{uuid.uuid4().hex}"],
            documents=[document],
        )

    print(
        f"Added {len(new_documents)} new documents to the collection '{collection_name}'.")

In [6]:
# Query the relevant passages
def get_relevant_passage(query: str, db, name: str, n_results: int = 3) -> List[str]:
    """
    Query a ChromaDB collection for the passages most relevant to a question.

    :param query: The user's question (string)
    :param db: ChromaDB client used to look up the collection
    :param name: Name of the collection to query (string)
    :param n_results: Number of results requested (int, default 3)
    :return: Matching passages for the query, or an empty list when the
             result holds no documents (List of strings)
    """
    collection = db.get_collection(name)
    results = collection.query(query_texts=[query], n_results=n_results)

    # "documents" holds one list per query text; only one query is sent, so
    # take the first list. Guard against a missing/empty result so callers
    # get [] instead of an IndexError.
    documents = results["documents"]
    return documents[0] if documents else []

In [18]:
# Build the prompt
def make_rag_prompt(query: str, relevant_passages: List[str], chat_history: List[str] = None) -> str:
    """
    Build the waiter prompt from the question, menu passages and chat history.

    :param query: The customer's question (string)
    :param relevant_passages: Menu passages used as context; a single string
                              is treated as one passage (List of strings)
    :param chat_history: Earlier conversation entries; only the last three
                         entries are included to limit prompt size. May be
                         omitted, and a single string is treated as one
                         entry (List of strings)
    :return: The complete prompt (string)
    """
    # Joining a bare string would interleave the separator between every
    # character, so wrap strings into one-element lists first.
    if isinstance(relevant_passages, str):
        relevant_passages = [relevant_passages]
    if chat_history is None:
        chat_history = []
    elif isinstance(chat_history, str):
        chat_history = [chat_history]

    context = "\n\n".join(relevant_passages)
    history = "\n".join(chat_history[-3:])  # keep only the last 3 entries

    return f"""
    You are a thoughtful waiter. Use the following conversation history and menu information to assist the customer.

    Previous Conversation:
    {history}

    Menu:
    {context}

    Customer's Question:
    {query}

    IMPORTANT:
    - If the customer is making a decision, **confirm their choice instead of recommending new dishes**.
    - Only recommend a new dish if the customer explicitly asks for suggestions.
    - If the customer asks about "it," infer that "it" refers to the last mentioned dish.
    - Always answer based on the context and prior recommendations.
    - **If the customer explicitly says "NOT" or "another" when asking for a recommendation, exclude the previously mentioned dish from the response.**

    Provide a concise but friendly response.
    """

# LLM Response Generation

In [None]:
# Check Gemini API key
from dotenv import load_dotenv
import os

# Load every variable defined in the .env file into the environment
load_dotenv()

# Read the Gemini API key from the environment
api_key = os.getenv('GEMINI_API_KEY')


# Confirm the key was loaded without echoing the secret itself: notebook
# output is stored with the file and is easily shared or committed.
print(f"Gemini api key loaded: {bool(api_key)}")

In [9]:
# Generate answer using Gemini Pro API
def generate_answer(prompt: str):
    """
    Send the prompt to the 'gemini-pro' model and return the reply text.

    The API key is read from the GEMINI_API_KEY environment variable; the
    .env file is not loaded here.

    :param prompt: Full prompt to send to the model (string)
    :return: Text of the generated response
    :raises ValueError: If GEMINI_API_KEY is missing or empty
    """
    key = os.environ.get('GEMINI_API_KEY')
    if key:
        genai.configure(api_key=key)
        return genai.GenerativeModel('gemini-pro').generate_content(prompt).text
    raise ValueError(
        "Gemini API Key not provided. Please provide GEMINI_API_KEY as an environment variable")

# Testing 

In [None]:
# Set up configurations
pdf_url = "https://services.google.com/fh/files/misc/ai_adoption_framework_whitepaper.pdf"
pdf_path = "ai_adoption_framework_whitepaper.pdf"

db_folder = "chroma_db"
db_path = os.path.join(os.getcwd(), db_folder)

# Create the database directory. exist_ok avoids the check-then-create race
# and targets the same absolute path the client is opened on.
os.makedirs(db_path, exist_ok=True)


client = chromadb.PersistentClient(path=db_path)

# a database unit in Chroma is called collection, so db here means collection
db_name = "rag_experiment"
client.get_or_create_collection(db_name)
print(f"{db_name} is created")

In [None]:
# Download the PDF and extract its text
download_pdf(pdf_url, pdf_path)
pdf_text = extract_text_from_pdf_file_path(pdf_path)

# Split text into chunks
chunked_text = split_text(pdf_text)

# Store every chunk in the collection
update_chroma_db(client, db_name, chunked_text)

In [None]:
# Process user query
query = 'what are the The AI maturity phases?'
relevant_text = get_relevant_passage(query, client, db_name, n_results=3)

# Generate and display answer
if relevant_text:
    # make_rag_prompt takes the passages as a list (it joins them itself)
    # and requires a chat history; this one-off test has none.
    final_prompt = make_rag_prompt(query, relevant_text, [])
    answer = generate_answer(final_prompt)
    print("\nGenerated Answer:", answer)
else:
    print("No relevant information found for the given query.")

# Combine Functions 

In [10]:
# Extract the text of a PDF, split it into chunks and add them to a ChromaDB collection.
def add_document_to_db(client, db_name, file):
    """
    :param client: ChromaDB client that owns the collection
    :param db_name: Name of the ChromaDB collection to update (string)
    :param file: File object of the PDF; its ``name`` attribute is used as the path
    :return: True when chunks were stored, False when nothing was added
             because the PDF could not be read or contained no text
    """
    pdf_text = extract_text_from_pdf_file_obj(file)

    # The extractor reports failures as a string with this prefix instead of
    # raising; indexing it would store the error message as menu content.
    if pdf_text.startswith("Error while reading PDF:"):
        print(pdf_text)
        return False

    # Split text into chunks, dropping empty ones (e.g. from image-only PDFs)
    chunked_text = [chunk for chunk in split_text(pdf_text) if chunk]
    if not chunked_text:
        print(f"No text could be extracted; {db_name} is unchanged")
        return False

    update_chroma_db(client, db_name, chunked_text)

    print(f"{db_name} is updated")
    return True

In [17]:
chat_history = []  # Log of the current conversation, shared across calls

# Generate an answer with the RAG (Retrieval-Augmented Generation) flow.
def rag_response(query, client, db_name):
    """
    Retrieve passages for the query, ask the model, and record the exchange.

    :param query: The user's question (string)
    :param client: ChromaDB client
    :param db_name: Name of the collection to query (string)
    :return: The generated answer prefixed with "Your Waiter:", or a
             message saying nothing relevant was found (string)
    """
    # Process user query
    relevant_text = get_relevant_passage(query, client, db_name, n_results=3)

    # Generate and display answer
    if relevant_text:
        # Pass both lists as lists: make_rag_prompt joins them itself, and a
        # pre-joined string would be split up character by character.
        final_prompt = make_rag_prompt(query, relevant_text, chat_history)
        answer = generate_answer(final_prompt)
        response = "Your Waiter:\n"+answer

        # Record recommendations under an explicit label so that later
        # prompts can tell which dish was suggested before.
        if "recommend" in answer.lower():
            chat_history.append(f"Waiter (previous recommendation): {answer}")

        chat_history.append(f"User: {query}")
        chat_history.append(f"Your Waiter: {answer}")
        # Trim in place to the six newest entries so the log stays bounded.
        chat_history[:] = chat_history[-6:]

    else:
        response = "No relevant information found for the given query."

    return response

# Main execution
## ToDo:
 - Chat history
 - Multiple file ingest

# Initialize

In [12]:
# Initialise the ChromaDB database: create its directory and set up the collection.
def initialize_database(db_folder: str, db_name: str) -> chromadb.PersistentClient:
    """
    :param db_folder: Name of the database folder, resolved against the
                      current working directory (string)
    :param db_name: Name of the collection (string)
    :return: Client connected to the database folder
    """
    # Build the full database path from the current working directory
    db_path = os.path.join(os.getcwd(), db_folder)

    # Create the directory if needed. exist_ok avoids the check-then-create
    # race and uses the same absolute path the client is opened on.
    os.makedirs(db_path, exist_ok=True)

    # Open a persistent client on the database path
    client = chromadb.PersistentClient(path=db_path)

    # Create the collection, or fetch it if it already exists
    client.get_or_create_collection(db_name)

    # Confirm that the collection is available
    print(f"Collection '{db_name}' is initialized in {db_folder}.")

    return client

In [None]:
# Set the parameters: database folder and collection name
db_folder = "chroma_db"
db_name = "rag_experiment"

# Create or open the database and keep the client for the cells below
client = initialize_database(db_folder, db_name)
print(client)

# gradio UI

In [None]:
# Demo interface


# Initialise the chat history
chat_history = []  # Rebinds the module-level list that rag_response reads and updates

# Interaction logic for user input


def respond(input_text, history):
    """
    Answer the user's message and append the exchange to the chat history.
    Args:
        input_text (str): The user's message.
        history (list): Chat history; None is treated as an empty history.
    Returns:
        tuple: An empty string (clears the input box) and the updated history.
    """
    # Start a fresh history when none is supplied
    chat_log = [] if history is None else history

    # Produce the reply through the RAG pipeline
    reply = rag_response(input_text, client, db_name)

    # Each turn is stored as [user message, bot response]
    chat_log.append([input_text, reply])

    return "", chat_log

# Handler for PDF uploads


def handle_pdf_upload(file):
    """
    Handle a PDF uploaded by the user.
    Args:
        file (File): The uploaded file object; its ``name`` attribute is
            used for the extension check. None means nothing was uploaded.
    Returns:
        str: Status message describing the outcome.
    """
    if file is None:
        return "尚未上傳文件。"

    # Accept the extension in any letter case, so "MENU.PDF" is not rejected
    if not file.name.lower().endswith(".pdf"):
        return "僅支持上傳 PDF 文件！"

    # Extract the text and add it to the collection
    add_document_to_db(client, db_name, file)
    return f"已上傳文件：{file.name}"


# Define the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## 智能點餐小幫手")

    # Chat window that displays the conversation
    chatbot = gr.Chatbot()

    # Text box for the user's message
    user_input = gr.Textbox(placeholder="請輸入你的消息...", label="輸入")

    # Button that clears the chat
    clear = gr.Button("清除聊天")

    # File upload (PDF only) and a read-only status box for the result
    file_upload = gr.File(label="上傳 PDF 文件", file_types=[
                          ".pdf"])
    file_status = gr.Textbox(label="文件狀態", interactive=False)

    def _reset_conversation():
        # Empty the history that rag_response feeds into the prompt as well;
        # otherwise a "cleared" chat still carries the earlier conversation.
        chat_history.clear()
        return []

    # Wire up the interactions
    file_upload.change(handle_pdf_upload, file_upload, file_status)  # handle uploads
    user_input.submit(respond, [user_input, chatbot], [
                      user_input, chatbot])  # handle chat input
    clear.click(_reset_conversation, None, chatbot)  # clear the widget and the history

In [86]:
# Placed here because once launch is running it is hard to reach the bottom-most block
demo.close()

In [None]:
# Start the Gradio app; share=True asks Gradio for a public share link
demo.launch(share=True)

In [None]:
# Re-create the client and fetch the collection, e.g. after a kernel restart
client = initialize_database("chroma_db", "rag_experiment")
collection = client.get_or_create_collection("rag_experiment")
print("Database initialized successfully!")