# 基于 Gemma + Milvus 实现多模态 RAG 系统

本练习展示了如何实现一个多模态的检索增强生成（RAG）系统。在检索增强生成中，除了输入提示外，还会使用外部信息源来生成响应。在多模态场景中，最常见的用例之一是将图像纳入响应生成过程。

<br>

本练习通过使用包含图像、文本和表格的 PDF 文件来实现多模态 RAG 系统。这个 PDF 文件就是前面提到的 RAG 定义中的外部信息源。一旦系统设置完成，模型将能够根据所提供的 PDF 文件中的图像、文本和表格生成响应。

以下是本练习涵盖的主题列表：

- 安装依赖项<br>
- 处理 PDF<br>
- 生成多模态嵌入<br>
- 创建向量数据库<br>
- 生成 RAG 响应<br>
- 测试 RAG 工作流<br>

## 安装依赖项

In [None]:
# !pip install -r requirements.txt

## 处理 PDF

In [1]:
import os
import pymupdf
from tqdm import tqdm
from pathlib import Path
import base64
from langchain_text_splitters import RecursiveCharacterTextSplitter


def pdf2imgs(pdf_path, pdf_pages_dir="data/pdf_pages"):
    """
    Render every page of a PDF file to an individual PNG image.

    Images are written to ``pdf_pages_dir`` as
    ``<pdf file name>_page_<NNN>.png`` where ``NNN`` is the zero-padded,
    0-based page index.

    Args:
        pdf_path (str): The path to the PDF file.
        pdf_pages_dir (str, optional): The directory to save the PNG images
            (created if missing). Defaults to "data/pdf_pages".

    Returns:
        str: The ``pdf_pages_dir`` argument, i.e. the directory containing the PNG images.
    """
    import pypdfium2 as pdfium

    pdf = pdfium.PdfDocument(pdf_path)
    try:
        os.makedirs(pdf_pages_dir, exist_ok=True)

        # Name the images after the PDF passed in (not a global variable);
        # process_pdf rebuilds page paths with this same naming scheme.
        pdf_name = str(pdf_path).split("/")[-1]

        n_pages = len(pdf)
        if n_pages == 0:
            return pdf_pages_dir

        # Probe the first page at the default scale (1 => 72 DPI). If it is
        # already large (>= 1620 px on its longest side) keep scale 1,
        # otherwise render at 300 DPI.
        resolution = pdf.get_page(0).render().to_numpy().shape
        scale = 1 if max(resolution) >= 1620 else 300 / 72

        for page_number in range(n_pages):
            page = pdf.get_page(page_number)
            pil_image = page.render(
                scale=scale,
                rotation=0,
                crop=(0, 0, 0, 0),
                may_draw_forms=False,
                fill_color=(255, 255, 255, 255),
                draw_annots=False,
                grayscale=False,
            ).to_pil()
            image_path = os.path.join(pdf_pages_dir, f"{pdf_name}_page_{page_number:03d}.png")
            pil_image.save(image_path)
    finally:
        # Release the native PDFium document handle.
        pdf.close()

    return pdf_pages_dir


def _save_text_chunks(page, page_num, filename, text_save_dir, text_splitter):
    """Chunk one page's text, write each chunk to ``text_save_dir`` and return "text" items."""
    items = []
    chunks = text_splitter.split_text(page.get_text())
    if chunks:
        # Only create the folder when there is something to write.
        os.makedirs(text_save_dir, exist_ok=True)
    for chunk_idx, chunk in enumerate(chunks):
        text_file_name = f"{text_save_dir}/{filename}_text_{page_num}_{chunk_idx}.txt"
        # Explicit UTF-8: extracted PDF text can contain non-ASCII characters (e.g. "∗", "Ł").
        with open(text_file_name, "w", encoding="utf-8") as f:
            f.write(chunk)
        items.append({"page": page_num, "type": "text", "text": chunk, "path": text_file_name})
    return items


def _save_embedded_images(doc, page, page_num, filename, image_save_dir):
    """Extract the images embedded in one page as PNG files and return "image" items."""
    items = []
    images = page.get_images()
    if images:
        os.makedirs(image_save_dir, exist_ok=True)
    for idx, image_info in enumerate(images):
        xref = image_info[0]
        pix = pymupdf.Pixmap(doc, xref)
        # PNG cannot store CMYK (4+ colour components); convert such pixmaps to RGB first.
        if pix.n - pix.alpha >= 4:
            pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
        png_bytes = pix.tobytes("png")
        image_name = f"{image_save_dir}/{filename}_image_{page_num}_{idx}_{xref}.png"
        with open(image_name, "wb") as f:
            f.write(png_bytes)
        items.append({
            "page": page_num,
            "type": "image",
            "path": image_name,
            # Encode from memory instead of re-reading the file we just wrote.
            "image": base64.b64encode(png_bytes).decode("utf8"),
        })
    return items


def _render_page_images(file_path, num_pages, page_images_save_dir):
    """Render every PDF page to PNG via ``pdf2imgs`` and return "page" items (best effort)."""
    try:
        page_images_dir = pdf2imgs(file_path, page_images_save_dir)
    except Exception as e:
        # Page renders are extra context: report the failure and keep the text/image items.
        print(f"Error in processing page image saving: {e}")
        return []

    pdf_name = str(file_path).split("/")[-1]  # same naming scheme pdf2imgs uses
    items = []
    for page_num in range(num_pages):
        page_path = os.path.join(page_images_dir, f"{pdf_name}_page_{page_num:03d}.png")
        with open(page_path, "rb") as f:
            page_image = base64.b64encode(f.read()).decode("utf8")
        items.append({"page": page_num, "type": "page", "path": page_path, "image": page_image})
    return items


def process_pdf(file_path: str, chunk_size: int = 1000, chunk_overlap: int = 100) -> "list[dict] | None":
    """
    Process a PDF file into retrieval items.

    1. Extract each page's text, split it with LangChain's
       RecursiveCharacterTextSplitter and save the chunks under ``data/text``.
    2. Extract the images embedded in each page and save them under ``data/images``.
    3. Render every page to a PNG image under ``data/page_images``.

    Args:
        file_path: Path to the PDF file.
        chunk_size: Target maximum chunk length passed to the text splitter.
        chunk_overlap: Overlap between consecutive chunks passed to the text splitter.

    Returns:
        A list of dicts, each with "page" (0-based page index), "type"
        ("text", "image" or "page") and "path". Text items also carry "text";
        image and page items carry "image" (base64-encoded PNG).
        Returns None if processing fails (the error is printed).
    """
    try:
        filename = str(file_path).split("/")[-1]
        data_dir = Path("./data/")
        image_save_dir = data_dir / "images"
        text_save_dir = data_dir / "text"
        page_images_save_dir = data_dir / "page_images"

        # Built once: the splitter configuration does not depend on the page.
        text_splitter = RecursiveCharacterTextSplitter(
            separators=[
                "\n\n",
                "\n",
                " ",
                ".",
                ",",
                "。",
                "．",
                "、",
                "，",
                "\u200b",  # zero-width space
            ],
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )

        items = []
        with pymupdf.open(file_path) as doc:
            num_pages = len(doc)
            for page_num in tqdm(range(num_pages), desc="Processing PDF pages"):
                page = doc[page_num]
                items.extend(_save_text_chunks(page, page_num, filename, text_save_dir, text_splitter))
                items.extend(_save_embedded_images(doc, page, page_num, filename, image_save_dir))

        items.extend(_render_page_images(file_path, num_pages, page_images_save_dir))
        return items
    except Exception as e:
        print(f"Error processing PDF: {str(e)}")
        return None


In [2]:
# Sample document used by the rest of this notebook.
file_path = "./data/transformers_paper.pdf"
items = process_pdf(file_path)
# process_pdf prints the error and returns None on failure; stop here with a
# clear message instead of failing later in an unrelated cell.
if items is None:
    raise RuntimeError(f"Failed to process PDF: {file_path}")
# Peek at the first few extracted items.
items[:3]

Processing PDF pages: 100%|██████████| 11/11 [00:00<00:00, 61.68it/s]


[{'page': 0,
  'type': 'text',
  'text': 'Attention Is All You Need\nAshish Vaswani∗\nGoogle Brain\navaswani@google.com\nNoam Shazeer∗\nGoogle Brain\nnoam@google.com\nNiki Parmar∗\nGoogle Research\nnikip@google.com\nJakob Uszkoreit∗\nGoogle Research\nusz@google.com\nLlion Jones∗\nGoogle Research\nllion@google.com\nAidan N. Gomez∗†\nUniversity of Toronto\naidan@cs.toronto.edu\nŁukasz Kaiser∗\nGoogle Brain\nlukaszkaiser@google.com\nIllia Polosukhin∗‡\nillia.polosukhin@gmail.com\nAbstract\nThe dominant sequence transduction models are based on complex recurrent or\nconvolutional neural networks that include an encoder and a decoder. The best\nperforming models also connect the encoder and decoder through an attention\nmechanism. We propose a new simple network architecture, the Transformer,\nbased solely on attention mechanisms, dispensing with recurrence and convolutions\nentirely. Experiments on two machine translation tasks show these models to\nbe superior in quality while being more 

## 生成多模态嵌入

### 加载 Visualized BGE 嵌入模型

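我们将使用 Visualized BGE 模型来生成图像和文本的嵌入向量。你需要克隆 FlagEmbedding 仓库并安装 `visual_bge` 包，然后将预训练权重下载到 `./models` 目录。下面的 `BGEVisualizedEncoder` 会同时加载英文（`Visualized_base_en_v1.5.pth`）和多语言（`Visualized_m3.pth`）两个权重文件。

注意以下两点：
- 在 Notebook 中，每条 `!` 命令都在独立的子 shell 中执行，因此 `cd` 必须与 `pip install` 写在同一条命令里。
- 下载权重时要使用 Hugging Face 的 `resolve` 链接；`blob` 链接返回的是网页，而不是权重文件。
```
!git clone https://github.com/FlagOpen/FlagEmbedding.git
!cd FlagEmbedding/research/visual_bge && pip install -e .
!mkdir -p models
!wget -P models https://huggingface.co/BAAI/bge-visualized/resolve/main/Visualized_base_en_v1.5.pth
!wget -P models https://huggingface.co/BAAI/bge-visualized/resolve/main/Visualized_m3.pth
```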

In [3]:
import sys
from pathlib import Path 


class BGEVisualizedEncoder:
    """Manager for BGE Visualized encoders with language support.

    Both encoders are loaded eagerly in ``__init__`` and switched to eval mode.

    Supported languages:
        - 'en': English model (bge-base-en-v1.5)
        - 'multilingual': Multilingual model (bge-m3)
    """

    def __init__(self, models_dir="./models", visual_bge_path="./FlagEmbedding/research/visual_bge"):
        """Initialize English and multilingual encoders.

        Args:
            models_dir (str): Directory holding ``Visualized_base_en_v1.5.pth`` and
                ``Visualized_m3.pth``. Defaults to "./models".
            visual_bge_path (str): Directory appended to ``sys.path`` so that
                ``visual_bge.modeling`` can be imported from the cloned FlagEmbedding repo.

        Raises:
            RuntimeError: If importing ``Visualized_BGE`` or loading either model
                fails; the underlying exception is chained as ``__cause__``.
        """
        try:
            # Append only once so repeated instantiation does not keep growing sys.path.
            if visual_bge_path not in sys.path:
                sys.path.append(visual_bge_path)
            from visual_bge.modeling import Visualized_BGE

            self.en_encoder = Visualized_BGE(
                model_name_bge="BAAI/bge-base-en-v1.5",
                model_weight=os.path.join(models_dir, "Visualized_base_en_v1.5.pth"),
            ).eval()

            self.m3_encoder = Visualized_BGE(
                model_name_bge="BAAI/bge-m3",
                model_weight=os.path.join(models_dir, "Visualized_m3.pth"),
            ).eval()
        except Exception as e:
            raise RuntimeError(f"Failed to initialize encoders: {e}") from e

    def get_encoder(self, language="en"):
        """Get the encoder for ``language``.

        Args:
            language (str): 'en' selects the English encoder; any other value
                (documented value: 'multilingual') selects the multilingual encoder.

        Returns:
            Visualized_BGE: Initialized encoder model.
        """
        return self.en_encoder if language == "en" else self.m3_encoder


def generate_bge_visualized_embeddings(encoder, image_path=None, text=None):
    """Generate an embedding for an image, a text, or an image+text pair.

    Args:
        encoder: Initialized Visualized_BGE model (anything with an
            ``encode(image=..., text=...)`` method whose result has ``.tolist()``).
        image_path (str, optional): Path to the input image.
        text (str, optional): Text to encode (alone, or together with the image).

    Returns:
        list: The first row of ``encoder.encode(...).tolist()``.

    Raises:
        ValueError: If neither ``image_path`` nor ``text`` is provided.
        FileNotFoundError: If ``image_path`` is given but does not exist.
        RuntimeError: If encoding fails; the original exception is chained.
    """
    if not image_path and not text:
        raise ValueError("Image path or text must be provided")

    if image_path and not Path(image_path).exists():
        raise FileNotFoundError(f"Image not found: {image_path}")

    # Pass only the modalities that were supplied.
    encode_kwargs = {}
    if image_path:
        encode_kwargs["image"] = image_path
    if text:
        encode_kwargs["text"] = text

    try:
        return encoder.encode(**encode_kwargs).tolist()[0]
    except Exception as e:
        raise RuntimeError(f"Encoding failed: {e}") from e



In [None]:

# Load both encoders, then pick the multilingual one ("en" also works).
bge_encoder = BGEVisualizedEncoder()
encoder = bge_encoder.get_encoder(language="multilingual")  # or "en"

# Text chunks are embedded from their text; "image" and "page" items are
# embedded from the image file at item['path'].
for item in tqdm(items, "Generating embeddings"):
    if item['type'] != 'text':
        item['vector'] = generate_bge_visualized_embeddings(encoder=encoder, image_path=item['path'])
    else:
        item['vector'] = generate_bge_visualized_embeddings(encoder=encoder, text=item['text'])


Generating embeddings: 100%|██████████| 54/54 [00:12<00:00,  4.42it/s]


## 创建向量数据库


### 将嵌入向量插入 Milvus

接下来，我们将使用 Milvus 向量数据库来存储所有条目（文本块、图像和整页图像）的嵌入向量。每个条目还会附带页码、文件路径，以及文本内容或 base64 编码的图像。

In [6]:
from pymilvus import MilvusClient

# Embedding dimension, taken from an embedded item instead of relying on the
# leftover loop variable of the previous cell and on dict key order.
dim = len(items[0]["vector"])

# Name of the Milvus collection
collection_name = "multimodal_rag_on_pdf"

# Connect to a local (file-based) Milvus instance; change the URI for your setup.
# Make sure the directory for the database file exists first.
os.makedirs("./db", exist_ok=True)
milvus_client = MilvusClient(uri='./db/multimodal_rag_milvus_project.db')

# Create the collection if it does not exist yet
if collection_name not in milvus_client.list_collections():
    milvus_client.create_collection(
        collection_name=collection_name,
        auto_id=True,
        dimension=dim,
        metric_type="COSINE",  # must match the metric used at search time
        enable_dynamic_field=True,  # store page/type/path/text/image as dynamic fields
    )

# To reuse an existing collection when re-running the notebook:
# milvus_client.load_collection(collection_name)
# print(f"Collection status: {milvus_client.get_load_state(collection_name)}")

# Insert the items; note that re-running this cell inserts them again.
milvus_client.insert(
    collection_name=collection_name,
    data=items,
)

{'insert_count': 54, 'ids': [456939301895143424, 456939301895143425, 456939301895143426, 456939301895143427, 456939301895143428, 456939301895143429, 456939301895143430, 456939301895143431, 456939301895143432, 456939301895143433, 456939301895143434, 456939301895143435, 456939301895143436, 456939301895143437, 456939301895143438, 456939301895143439, 456939301895143440, 456939301895143441, 456939301895143442, 456939301895143443, 456939301895143444, 456939301895143445, 456939301895143446, 456939301895143447, 456939301895143448, 456939301895143449, 456939301895143450, 456939301895143451, 456939301895143452, 456939301895143453, 456939301895143454, 456939301895143455, 456939301895143456, 456939301895143457, 456939301895143458, 456939301895143459, 456939301895143460, 456939301895143461, 456939301895143462, 456939301895143463, 456939301895143464, 456939301895143465, 456939301895143466, 456939301895143467, 456939301895143468, 456939301895143469, 456939301895143470, 456939301895143471, 45693930189

## 生成RAG回复

In [9]:
from openai import OpenAI

def generate_rag_response(prompt, matched_items, *, client=None, model="gemma-3-27b-it", max_tokens=500):
    """Answer ``prompt`` with a chat model, using retrieved items as context.

    Args:
        prompt (str): The user's question, inserted into the prompt.
        matched_items (list[dict]): Retrieved entities. Items with a non-empty
            "text" are added to the text context as "<page>. <text>"; other
            items with an "image" (base64-encoded PNG) are attached as images.
        client: OpenAI-compatible client. Defaults to a local LM Studio server
            at http://localhost:1234/v1.
        model (str): Chat model name.
        max_tokens (int): Completion token limit.

    Returns:
        str: The model's reply text.
    """
    if client is None:
        client = OpenAI(base_url="http://localhost:1234/v1", api_key="lm-studio")

    text_parts = []
    image_context = []
    for item in matched_items:
        # Use .get(): a search hit may lack a field, or carry it as None.
        if item.get("text"):
            text_parts.append(f"{item['page']}. {item['text']}\n")
        elif item.get("image"):
            image_context.append(item["image"])
    text_context = "".join(text_parts)

    final_prompt = f"""You are a helpful assistant for question answering.
    The text context is relevant information retrieved.
    The provided image(s) are relevant information retrieved.
    
    <context>
    {text_context}
    </context>
    
    Answer the following question using the relevant context and images.
    
    <question>
    {prompt}
    </question>
    
    Answer:"""

    # One image_url part per image; the stored images are PNG-encoded.
    content = [{"type": "text", "text": final_prompt}]
    content.extend(
        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}}
        for image_b64 in image_context
    )

    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": content}],
        max_tokens=max_tokens,  # adjust as needed
    )
    return response.choices[0].message.content
    

In [None]:
query = "How is the scaled-dot-product attention calculated?"
query_embedding = generate_bge_visualized_embeddings(encoder=encoder, text=query, image_path=None)

# Search Milvus for the items closest to the query embedding
search_results = milvus_client.search(
    collection_name=collection_name,
    data=[query_embedding],
    output_fields=['text', 'page', 'image'],
    search_params={"metric_type": "COSINE", "params": {}},
    limit=3,  # number of hits to return
)[0]

matched_items = [hit["entity"] for hit in search_results]

# Pass the user's question as the prompt so the model knows what to answer.
results = generate_rag_response(query, matched_items)
print(results)



## 未来改进方向

1. 添加更多模态的支持（如视频、音频）
2. 优化 PDF 中图片和表格的抽取效果
3. 优化向量检索策略
4. 提升回答生成的质量
5. 添加结果评估机制
6. 利用 FastAPI 构建聊天服务与 UI





## 参考链接



- BGE-Visualized 项目：https://github.com/FlagOpen/FlagEmbedding
- Hugging Face 模型仓库：https://huggingface.co/BAAI/bge-visualized
- Model Scope模型仓库：https://www.modelscope.cn/models/BAAI/bge-visualized/summary
- Gemma 模型：https://ai.google.dev/gemma
- Milvus 文档：https://milvus.io/docs