In [1]:
import os
import json
import logging
from pathlib import Path
from typing import List, Dict, Optional, Tuple
from IPython.display import Image, display
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
import faiss
from vertexai.generative_models import GenerativeModel, Part, Image as VertexImage
import vertexai
import PyPDF2
from PIL import Image as PILImage

# Configure logging once for the whole notebook; every cell below logs through `logger`
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Make sure the working directories used for saved embeddings and images exist
for _required_dir in ('embeddings', 'image'):
    Path(_required_dir).mkdir(exist_ok=True)

  from tqdm.autonotebook import tqdm, trange


## 2. 資料處理和 Embedding 生成

In [2]:
class DataProcessor:
    """Turn a PDF and a post/response CSV into documents with text embeddings.

    Every document is a dict with a ``content`` entry (a paragraph string for
    PDFs, a post dict for CSV rows) and a ``metadata`` entry whose ``type`` is
    either ``'pdf'`` or ``'social_post'``.
    """

    def __init__(self, text_model_name: str = 'all-MiniLM-L6-v2'):
        """Load the sentence-transformers model used to embed document text."""
        self.text_model = SentenceTransformer(text_model_name)

    def process_pdf(self, pdf_path: str) -> List[Dict]:
        """Split a PDF into paragraph documents.

        Each page's text is split on blank lines; every non-empty paragraph
        becomes one document tagged with the source path and 1-based page
        number.

        Returns:
            The list of paragraph documents, or an empty list when the PDF
            cannot be read (the error is logged, not raised).
        """
        logger.info(f"Processing PDF: {pdf_path}")
        documents = []

        try:
            with open(pdf_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                for page_num, page in enumerate(pdf_reader.pages, start=1):
                    # Guard against a page yielding no text (None) so one such
                    # page does not abort the whole document via the except below.
                    text = page.extract_text() or ""
                    paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]

                    for para in paragraphs:
                        documents.append({
                            'content': para,
                            'metadata': {
                                'source': pdf_path,
                                'page': page_num,
                                'type': 'pdf'
                            }
                        })

            logger.info(f"Extracted {len(documents)} paragraphs from PDF")
            return documents
        except Exception as e:
            # Deliberate best-effort: a broken PDF must not stop CSV processing.
            logger.error(f"Error processing PDF {pdf_path}: {str(e)}")
            return []

    @staticmethod
    def _as_post_document(post: Dict) -> Dict:
        """Wrap a collected post dict in the common document structure."""
        return {'content': post, 'metadata': {'type': 'social_post'}}

    def process_csv(self, csv_path: str, image_dir: str = "image") -> List[Dict]:
        """Group CSV rows into posts with their responses and images.

        A row with a non-empty ``post`` cell starts a new post; following rows
        with an empty ``post`` cell add their ``responses`` / ``images`` cells
        to that post. Image names are kept only when the file exists under
        ``image_dir``. The ``responses``, ``images`` and ``link`` columns are
        optional.

        Rows that appear before the first post have nothing to attach to and
        are skipped with a warning.

        Returns:
            One ``social_post`` document per post, in file order.
        """
        logger.info(f"Processing CSV: {csv_path}")
        df = pd.read_csv(csv_path)
        image_root = Path(image_dir)
        posts = []
        current_post = None

        for row_num, row in df.iterrows():
            if pd.notna(row['post']):
                if current_post:
                    posts.append(self._as_post_document(current_post))
                link = row.get('link')
                current_post = {
                    'post': row['post'],
                    'responses': [],
                    'images': [],
                    'link': link if pd.notna(link) else None
                }
            elif current_post is None:
                # Continuation row without any post above it.
                logger.warning(f"Skipping CSV row {row_num}: no preceding post")
                continue

            response = row.get('responses')
            if pd.notna(response):
                current_post['responses'].append(response)

            # Attach the image only when the referenced file is really there
            image_name = row.get('images')
            if pd.notna(image_name):
                image_name = str(image_name).strip()
                if (image_root / image_name).exists():
                    current_post['images'].append(image_name)
                    logger.info(f"Added image: {image_name}")
                else:
                    logger.warning(f"Image not found, skipped: {image_root / image_name}")

        if current_post:
            posts.append(self._as_post_document(current_post))

        logger.info(f"Processed {len(posts)} posts from CSV")
        return posts

    @staticmethod
    def _embedding_text(doc: Dict) -> str:
        """Return the text that represents ``doc`` in the embedding space."""
        if doc['metadata']['type'] == 'pdf':
            return doc['content']
        # social_post: the post followed by all of its responses
        content = doc['content']
        # str() because CSV cells that look numeric are not read as strings
        responses = ' '.join(str(resp) for resp in content['responses'])
        return f"{content['post']} {responses}"

    def create_embeddings(self, documents: List[Dict]) -> List[Dict]:
        """Attach a text embedding to every document.

        All texts are encoded in a single ``encode`` call so the model can
        batch them.

        Returns:
            A new list of dicts with ``text_embedding`` (list of floats),
            ``content`` and ``metadata``; empty when ``documents`` is empty.
        """
        if not documents:
            return []

        texts = [self._embedding_text(doc) for doc in documents]
        embeddings = self.text_model.encode(texts)

        return [
            {
                'text_embedding': embedding.tolist(),
                'content': doc['content'],
                'metadata': doc['metadata']
            }
            for doc, embedding in zip(documents, embeddings)
        ]

    def process_and_save(self,
                        csv_path: str,
                        pdf_path: str,
                        save_path: str = 'embeddings/embeddings.json',
                        image_dir: str = "image"):
        """Process the PDF and CSV, embed everything and save it as JSON.

        Args:
            csv_path: CSV file with the posts and responses.
            pdf_path: PDF file to split into paragraphs.
            save_path: Destination JSON file; its parent directory is created
                when missing.
            image_dir: Directory searched for the images named in the CSV.

        Returns:
            The list of embedded documents that was written to ``save_path``.
        """
        pdf_docs = self.process_pdf(pdf_path)
        csv_docs = self.process_csv(csv_path, image_dir=image_dir)

        embedded_docs = self.create_embeddings(pdf_docs + csv_docs)

        output_file = Path(save_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(embedded_docs, f, ensure_ascii=False, indent=2)

        logger.info(f"Saved {len(embedded_docs)} embedded documents to {save_path}")
        return embedded_docs

## 3. RAG 系統實作

In [3]:
class MultiModalRAGSystem:
    """Retrieve saved documents with FAISS and answer questions with a Vertex AI model.

    The saved embeddings are loaded into a flat L2 FAISS index. A query is
    embedded with a sentence-transformers model, the nearest documents are
    looked up, and their text plus any images found on disk are sent to the
    generative model.
    """

    # Class-level fallbacks so the retrieval helpers also work on an instance
    # that was created without running __init__.
    text_model_name = 'all-MiniLM-L6-v2'
    _query_encoder = None

    def __init__(self,
                 project_id: str,
                 location: str = "us-central1",
                 image_dir: str = "image",
                 embeddings_path: str = 'embeddings/embeddings.json',
                 text_model_name: str = 'all-MiniLM-L6-v2',
                 llm_model_name: str = 'gemini-1.0-pro-vision-001',
                 credentials_path: str = 'key/gemini-key.json'):
        """Initialise Vertex AI and load the saved embeddings.

        Args:
            project_id: Google Cloud project used for Vertex AI.
            location: Vertex AI region.
            image_dir: Directory holding the images referenced by the posts.
            embeddings_path: JSON file with the saved embedded documents.
            text_model_name: sentence-transformers model used for queries.
                It should be the model the saved embeddings were created with;
                otherwise distances are meaningless.
            llm_model_name: Vertex AI generative model name.
                NOTE(review): the default is an old versioned model name and
                may no longer be served; confirm and pass a current
                multimodal model if needed.
            credentials_path: Service-account key used only when
                GOOGLE_APPLICATION_CREDENTIALS is not already set.
        """
        self.image_dir = Path(image_dir)
        self.text_model_name = text_model_name
        self._query_encoder = None

        # Point at the local key only when it really exists; otherwise leave the
        # variable unset so the default credential lookup is not broken by a
        # path to a missing file.
        if ('GOOGLE_APPLICATION_CREDENTIALS' not in os.environ
                and Path(credentials_path).is_file()):
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = credentials_path

        vertexai.init(project=project_id, location=location)
        self.llm = GenerativeModel(llm_model_name)

        self.load_saved_embeddings(embeddings_path)

    def load_saved_embeddings(self, embeddings_path: str = 'embeddings/embeddings.json'):
        """Load the saved embedded documents and rebuild the FAISS index.

        Raises:
            ValueError: if the file holds no documents or the embeddings do
                not form a 2-D matrix.
        """
        logger.info(f"Loading embeddings from {embeddings_path}")

        with open(embeddings_path, 'r', encoding='utf-8') as f:
            embeddings_data = json.load(f)

        if not embeddings_data:
            raise ValueError(f"No embeddings found in {embeddings_path}")

        # One float32 matrix, added in a single call instead of row by row.
        matrix = np.asarray(
            [data['text_embedding'] for data in embeddings_data],
            dtype=np.float32
        )
        if matrix.ndim != 2:
            raise ValueError(f"Embeddings in {embeddings_path} do not form a 2-D matrix")

        index = faiss.IndexFlatL2(matrix.shape[1])
        index.add(matrix)

        # Publish both together so a failed reload never leaves a mismatched pair.
        self.embeddings_data = embeddings_data
        self.index = index

        logger.info(f"Successfully loaded {len(self.embeddings_data)} embeddings")

    def _get_query_encoder(self):
        """Return the query embedding model, loading it on first use only."""
        if self._query_encoder is None:
            self._query_encoder = SentenceTransformer(self.text_model_name)
        return self._query_encoder

    def get_relevant_docs(self, query: str, k: int = 3) -> List[Dict]:
        """Return up to ``k`` documents nearest to ``query``.

        Returns:
            A list of dicts with ``content`` and ``metadata``; shorter than
            ``k`` when the index holds fewer documents.
        """
        # Never ask for more neighbours than the index contains.
        k = min(k, self.index.ntotal)
        if k <= 0:
            return []

        query_embedding = np.asarray(
            self._get_query_encoder().encode(query), dtype=np.float32
        ).reshape(1, -1)

        _, neighbor_ids = self.index.search(query_embedding, k)

        relevant_docs = []
        for idx in neighbor_ids[0]:
            # FAISS marks "no neighbour" with -1, which would otherwise
            # silently select the last document.
            if idx < 0:
                continue
            entry = self.embeddings_data[int(idx)]
            relevant_docs.append({
                'content': entry['content'],
                'metadata': entry['metadata']
            })

        return relevant_docs

    def _existing_images(self, content) -> List[Tuple[str, Path]]:
        """Return (name, path) for every image of a post that exists on disk."""
        if not isinstance(content, dict):
            return []
        found = []
        for img in content.get('images') or []:
            img_path = self.image_dir / img
            if img_path.exists():
                found.append((img, img_path))
        return found

    def display_relevant_images(self, relevant_docs: List[Dict]):
        """Display the images of the retrieved social posts with their discussion."""
        found_images = False
        for doc in relevant_docs:
            if doc['metadata']['type'] != 'social_post':
                continue
            content = doc['content']
            for img, img_path in self._existing_images(content):
                display(Image(filename=str(img_path)))
                print(f"圖片說明: {img}")
                print("相關討論:")
                print(f"問題: {content['post']}")
                print("回應:")
                for resp in content.get('responses', []):
                    print(f"- {resp}")
                print("\n")
                found_images = True

        if not found_images:
            print("未找到相關圖片")

    def _build_context(self, relevant_docs: List[Dict]) -> Tuple[str, List]:
        """Build the reference text and the list of images sent to the model."""
        context = "以下是相關的參考資料：\n\n"
        vertex_images = []

        for doc in relevant_docs:
            doc_type = doc['metadata']['type']
            content = doc['content']
            if doc_type == 'pdf':
                context += f"【醫學文獻】\n{content}\n\n"
            elif doc_type == 'social_post':
                context += f"【社群討論】\n問題：{content['post']}\n"
                if content.get('responses'):
                    context += "回應：\n"
                    for resp in content['responses']:
                        context += f"- {resp}\n"

                for img, img_path in self._existing_images(content):
                    vertex_image = VertexImage.load_from_file(str(img_path))
                    if vertex_image:
                        vertex_images.append(vertex_image)
                        context += f"[包含圖片: {img}]\n"

                if content.get('link'):
                    context += f"來源：{content['link']}\n"
                context += "\n"

        return context, vertex_images

    def generate_response(self, query: str) -> Tuple[str, List[Dict]]:
        """Answer ``query`` from the retrieved documents.

        Returns:
            A tuple of the model's answer text and the retrieved documents.
        """
        logger.info(f"Processing query: {query}")

        relevant_docs = self.get_relevant_docs(query)
        context, vertex_images = self._build_context(relevant_docs)

        prompt = f"""你是一位專業的獸醫師，專精於寵物醫療和行為諮詢，特別在老年寵物照護和認知障礙方面有豐富經驗。
請根據提供的參考資料和圖片，針對用戶的問題提供專業的建議。在回答中，請具體描述相關圖片的內容及其對答案的啟發。

參考資料：
{context}

用戶問題：{query}

請提供專業且具體的建議，並請：
1. 描述相關圖片內容及其啟發
2. 引用參考資料來源（醫學文獻/社群討論）
3. 如涉及醫療建議，提醒諮詢獸醫
"""

        # The prompt goes first, followed by any images that were found
        contents = [prompt, *vertex_images]

        response = self.llm.generate_content(contents)

        return response.text, relevant_docs

## 4. 系統使用示例
### 4.1 生成新的 Embeddings

In [4]:
# Build the data processor (this loads the sentence-embedding model)
processor = DataProcessor()

# Embed the CSV and PDF sources and save the result with the default save path
source_paths = {"csv_path": "post_response.csv", "pdf_path": "salvin2010.pdf"}
embedded_docs = processor.process_and_save(**source_paths)

print(f"成功生成 {len(embedded_docs)} 個文件的 embeddings")

INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: mps
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: all-MiniLM-L6-v2
INFO:__main__:Processing PDF: salvin2010.pdf
INFO:__main__:Extracted 5 paragraphs from PDF
INFO:__main__:Processing CSV: post_response.csv
INFO:__main__:Added image: image01.jpg
INFO:__main__:Added image: image02.jpg
INFO:__main__:Processed 3 posts from CSV
Batches: 100%|██████████| 1/1 [00:00<00:00,  3.15it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 56.00it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 57.72it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 47.11it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 39.71it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 40.04it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 41.56it/s]
Batches: 100%|██████████| 1/1 [00:00<00:00, 68.49it/s]
INFO:__main__:Saved 8 embedded documents to embeddings/embeddings.json


成功生成 8 個文件的 embeddings


### 4.2 初始化和測試 RAG 系統

In [5]:
# Initialise the RAG system.
# The project ID is read from the GOOGLE_CLOUD_PROJECT environment variable when
# it is set, so the notebook can run against another project without editing
# this cell; the literal below is only the fallback.
rag_system = MultiModalRAGSystem(
    project_id=os.environ.get("GOOGLE_CLOUD_PROJECT") or "high-tribute-438514-j7",
    location="us-central1",
    image_dir="image"
)

# 測試查詢
def test_query(query: str):
    print("問題：", query)
    print("\n相關圖片：")
    
    # 生成回應並獲取相關文件
    response, relevant_docs = rag_system.generate_response(query)
    
    # 顯示相關圖片
    rag_system.display_relevant_images(relevant_docs)
    
    print("\nGemini的回答：")
    print(response)
    print("\n" + "-"*50 + "\n")

INFO:__main__:Loading embeddings from embeddings/embeddings.json
INFO:__main__:Successfully loaded 8 embeddings


In [None]:
# Sample questions that exercise the full retrieve-and-answer pipeline
test_queries = [
    "繞圈圈的狗有適合她活動的佈置嗎？",
    "老狗失智症有什麼症狀？",
    "晚上狗狗一直叫該怎麼辦？"
]

# Run them in order; each call prints its own question, images and answer
for sample_query in test_queries:
    test_query(sample_query)

### 4.3 互動式查詢介面

In [None]:
from IPython.display import clear_output

def interactive_query():
    """Prompt for questions in a loop and answer each with ``test_query``.

    The loop ends when the user types ``quit`` (any letter case), when input
    is closed, or when the prompt is interrupted. Blank input is ignored
    instead of being sent as a query.
    """
    while True:
        try:
            query = input("請輸入您的問題 (輸入'quit'結束): ").strip()
        except (EOFError, KeyboardInterrupt):
            # No more input or the prompt was interrupted: stop without a traceback
            break

        if query.lower() == 'quit':
            break

        if not query:
            # Nothing typed; ask again rather than querying with an empty string
            continue

        clear_output(wait=True)
        test_query(query)
        print("\n輸入新的問題或輸入'quit'結束")

# Start the interactive query loop (waits for user input until 'quit' is entered)
interactive_query()