# Ovis VLM with RAG (Retrieval-Augmented Generation)

This notebook demonstrates how to add RAG capabilities to the Ovis Vision-Language Model. With a retrieval system in front of it, the model can ground its answers in passages from an external knowledge base, which tends to make responses more accurate and more detailed than answers drawn from the image and the model's own parameters alone.

In [4]:
!pip install Pillow sentence-transformers pandas openpyxl chromadb langchain


## Step 1: Prepare Knowledge Base

First, we'll create a simple knowledge base. We'll load text from an Excel file (falling back to a few sample sentences if the file can't be read), split it into manageable chunks, and then convert the chunks into vector embeddings using a sentence transformer model. These embeddings will be stored in a ChromaDB collection for fast retrieval.
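
To make the roles of `chunk_size` and `chunk_overlap` concrete, here is a minimal sketch of overlapping character windows. It is a simplification, not the algorithm `RecursiveCharacterTextSplitter` uses (that splitter also tries to break on the configured separators first); the helper name `chunk_text` and the toy input are assumptions made for illustration.

```python
def chunk_text(text: str, chunk_size: int = 500, chunk_overlap: int = 50) -> list[str]:
    """Split text into fixed-size character windows that overlap by chunk_overlap."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap  # how far the window advances each time
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the last window already reached the end of the text
    return chunks


# Toy example: windows of 10 characters, each sharing 3 characters with the next
sample = "abcdefghijklmnopqrstuvwxyz"
chunks = chunk_text(sample, chunk_size=10, chunk_overlap=3)
print(chunks)
```

The overlap means a sentence that straddles a chunk boundary still appears intact in at least one chunk, at the cost of storing some text twice.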

In [7]:
import pandas as pd
from sentence_transformers import SentenceTransformer
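# Import from the standalone splitter package (installed as a langchain dependency);
# the older `langchain.text_splitter` path is not available in every langchain release.
from langchain_text_splitters import RecursiveCharacterTextSplitter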

# Load the Excel file
try:
    df = pd.read_excel('/home/aisw/Project/UST-ETRI-2025/VLM_RAG/data/korean_train_20250101.xlsx')
    # Combine all text into a single string
    full_text = "\n".join(df['Text'].dropna().astype(str).tolist())
    print("Successfully loaded and combined text from Excel file.")
    
    # 1. Initialize the text splitter
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,      # Size of each chunk
        chunk_overlap=50,    # Overlap between chunks
        separators=["\n\n", "\n", " ", ""] # How to split the text
    )

    # 2. Split the text into documents (chunks)
    documents = text_splitter.split_text(full_text)
    
    print(f"Split the text into {len(documents)} documents (chunks).")

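except Exception as e:
    # Any failure (missing file, missing 'Text' column, unreadable workbook) falls back
    # to the same sample data, so the sample list is defined only once.
    if isinstance(e, FileNotFoundError):
        print("Error: korean_train_20250101.xlsx not found. Using sample data instead.")
    else:
        print(f"An error occurred: {e}. Using sample data.")
    # Sample Korean sentences about Gwangjang Market (English glosses in the comments)
    documents = [
        "광장시장은 대한민국 서울특별시 종로구에 위치한 전통 시장이다.",  # Gwangjang Market is a traditional market in Jongno-gu, Seoul, South Korea.
        "1905년에 개설되었으며, 대한민국 최초의 상설 시장으로 알려져 있다.",  # It opened in 1905 and is known as Korea's first permanent market.
        "주요 판매 품목은 한복, 직물, 구제 의류, 그리고 다양한 먹거리이다.",  # Main goods: hanbok, textiles, secondhand clothing, and a variety of food.
        "특히 빈대떡, 마약김밥, 육회 등이 유명하여 많은 관광객들이 찾는다.",  # Bindaetteok, mayak gimbap, and yukhoe are especially famous and draw many tourists.
    ]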

# Load the BGE-M3 multilingual embedding model
embedding_model = SentenceTransformer('BAAI/bge-m3')

# Create embeddings for the documents
doc_embeddings = embedding_model.encode(documents, show_progress_bar=True)

print("Document embeddings created successfully.")
print("Shape of embeddings:", doc_embeddings.shape)


In [None]:
import chromadb

# Initialize ChromaDB client (in-memory)
client = chromadb.Client()

# Create a new collection or get an existing one
collection_name = "korean_knowledge_base"
collection = client.get_or_create_collection(name=collection_name)

# Generate IDs for each document
doc_ids = [str(i) for i in range(len(documents))]

# Add documents to the collection
collection.add(
    embeddings=doc_embeddings.tolist(),
    documents=documents,
    ids=doc_ids
)

print(f"ChromaDB collection '{collection_name}' created/updated with {collection.count()} documents.")
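
The `collection.add` call above sends every chunk in one request. For a large knowledge base it may be safer to add in smaller batches, since some ChromaDB versions reject a single oversized `add` call. The sketch below shows only the batching logic; `iter_batches` and the batch size of 1000 are hypothetical choices, and the ChromaDB usage is left as a comment because it depends on the variables defined in the cells above.

```python
def iter_batches(items, batch_size):
    """Yield consecutive slices of items, each with at most batch_size elements."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


# Toy example: 7 IDs split into batches of 3
example_ids = [str(i) for i in range(7)]
batches = list(iter_batches(example_ids, batch_size=3))
print(batches)

# Hypothetical use with the notebook's variables (not executed here):
# batch_size = 1000  # assumed value; tune for your ChromaDB version
# for start in range(0, len(documents), batch_size):
#     end = start + batch_size
#     collection.add(
#         embeddings=doc_embeddings[start:end].tolist(),
#         documents=documents[start:end],
#         ids=doc_ids[start:end],
#     )
```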

## Step 2: Implement Retriever

Now, we'll create a retriever function. This function will take a user's query, embed it using the same sentence transformer model, and then query the ChromaDB collection to find the most relevant document chunks.
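
Conceptually, retrieval ranks document embeddings by how close they are to the query embedding. The brute-force sketch below illustrates that idea with cosine similarity on toy 3-dimensional vectors. It is not how ChromaDB works internally (ChromaDB uses its own index and its configured distance metric), and `cosine_similarity`, `top_k`, and the toy data are assumptions made for illustration.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two non-zero vectors given as lists of floats."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k(query_vec, doc_vecs, docs, k=2):
    """Return the k documents whose vectors are most similar to query_vec."""
    scored = [(cosine_similarity(query_vec, vec), doc) for vec, doc in zip(doc_vecs, docs)]
    scored.sort(key=lambda pair: pair[0], reverse=True)  # highest similarity first
    return [doc for _, doc in scored[:k]]


# Toy data: hand-written vectors stand in for real sentence embeddings
toy_docs = ["market location", "market history", "unrelated topic"]
toy_vecs = [[1.0, 0.0, 0.0], [0.8, 0.6, 0.0], [0.0, 0.0, 1.0]]
toy_query = [0.9, 0.1, 0.0]
print(top_k(toy_query, toy_vecs, toy_docs, k=2))
```

A vector database does the same ranking without comparing the query against every stored vector, which is what keeps retrieval fast as the knowledge base grows.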

In [None]:
def retrieve_documents(query, k=2):
    # Embed the query
    query_embedding = embedding_model.encode([query]).tolist()
    
    # Query the collection
    results = collection.query(
        query_embeddings=query_embedding,
        n_results=k
    )
    
    # Return the retrieved documents
    return results['documents'][0]

# Test the retriever with a Korean query ("What do they sell at Gwangjang Market?")
test_query = "광장시장에서 뭘 파나요?"
retrieved = retrieve_documents(test_query)
print(f"Query: {test_query}")
print("Retrieved documents:")
for doc in retrieved:
    print(f"- {doc}")

## Step 3: Load Ovis VLM Model

Next, we load the Ovis VLM model along with its text and visual tokenizers. This code is adapted from the project's `main.py` script.

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image

model_path = "AIDC-AI/Ovis2-8B"

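# Prefer bfloat16 when a CUDA device supports it; otherwise fall back to float16.
# The is_available() guard avoids querying bf16 support on machines without CUDA.
torch_dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16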

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch_dtype,
    trust_remote_code=True,
    cache_dir="./hf_cache",
    device_map="auto",
    low_cpu_mem_usage=True,
)

tokenizer = model.get_text_tokenizer()
visual_tokenizer = model.get_visual_tokenizer()

print("Ovis VLM model and tokenizers loaded successfully.")

## Step 4: RAG-Enhanced Inference

Finally, we'll combine everything. We'll take an image and a question, use our retriever to find relevant information, construct a new prompt with this context, and then feed it to the Ovis model to get a RAG-enhanced answer.
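
The prompt-construction step can be sketched as a small helper that joins the retrieved chunks, caps the total context length, and appends the `<image>` placeholder used in the cell below. The function name `build_rag_prompt`, the English section labels, and the `max_context_chars` limit are assumptions for illustration; the notebook itself builds a Korean prompt inline.

```python
def build_rag_prompt(context_chunks, question, max_context_chars=1500):
    """Assemble a RAG prompt from retrieved chunks, a question, and an <image> placeholder."""
    kept, used = [], 0
    for chunk in context_chunks:
        # Stop before the context budget is exceeded so the prompt stays bounded
        if used + len(chunk) > max_context_chars:
            break
        kept.append(chunk)
        used += len(chunk)
    context = "\n".join(kept)
    return (
        "Answer the question about the image using the provided context.\n"
        f"[Context]\n{context}\n\n"
        f"[Question]\n{question}\n"
        "<image>"
    )


# Toy example: the third chunk is too long for the budget and is dropped
demo_prompt = build_rag_prompt(
    ["chunk A", "chunk B", "x" * 2000],
    "What is sold here?",
    max_context_chars=20,
)
print(demo_prompt)
```

Capping the context is a design choice rather than a requirement: it keeps long retrieved passages from crowding out the question, at the risk of dropping a relevant chunk.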

In [None]:
# Use an image of Gwangjang Market for this example.
import os
image_path = '/home/aisw/Project/UST-ETRI-2025/data/광장시장.jpg' # Please make sure this image exists

if not os.path.exists(image_path):
    print(f"Warning: Image not found at {image_path}. Please upload an image of Gwangjang Market.")
    # As a fallback, create a solid-color placeholder image
    images = [Image.new('RGB', (512, 512), color='blue')]
else:
    # Convert to RGB so grayscale or RGBA files are handled uniformly
    images = [Image.open(image_path).convert('RGB')]

# User's question about the image, in Korean
# ("What is this market's history, and what foods is it famous for?")
user_question = "이 시장의 역사와 유명한 먹거리는 무엇인가요?"

# 1. Retrieve relevant documents
retrieved_context = retrieve_documents(user_question, k=3)
context_str = "\n".join(retrieved_context)

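# 2. Create the augmented prompt
# The Korean instruction reads: "Answer the question about the image using the provided context."
# [문맥] marks the context section and [질문] marks the question section.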
rag_prompt = f"""
제공된 문맥 정보를 사용하여 이미지에 대한 질문에 답하세요.
[문맥]
{context_str}

[질문]
{user_question}
<image>
"""

# 3. Preprocess and run inference
max_partition = 9
prompt, input_ids, pixel_values = model.preprocess_inputs(rag_prompt, images, max_partition=max_partition)
attention_mask = torch.ne(input_ids, tokenizer.pad_token_id)
input_ids = input_ids.unsqueeze(0).to(device=model.device)
attention_mask = attention_mask.unsqueeze(0).to(device=model.device)
if pixel_values is not None:
    pixel_values = pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)
pixel_values = [pixel_values]

# Generate output
with torch.inference_mode():
    gen_kwargs = dict(
        max_new_tokens=1024,
        do_sample=False,
    )
    output_ids = model.generate(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, **gen_kwargs)[0]
    output = tokenizer.decode(output_ids, skip_special_tokens=True)

print("--- RAG-based answer ---")
print(f'Result:\n{output}')