# Contextual Retrieval Test

This notebook implements and tests contextual retrieval functionality that enhances document chunks with contextual information using the specified prompt.

In [28]:
import os
import re
from typing import List
from pathlib import Path
from langchain.docstore.document import Document
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_ollama import OllamaEmbeddings, ChatOllama
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import MarkdownHeaderTextSplitter
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

## 1. Document Loading and Splitting

In [None]:
def hierarchical_markdown_split(md_text: str, path_prefix: str = "") -> list[Document]:
    """마크다운 문서를 계층적으로 분할합니다."""
    splitter = MarkdownHeaderTextSplitter(headers_to_split_on=[
        ("#", "title"),
        ("##", "section"),
        ("###", "subsection"),
        ("####", "subsubsection"),
        ("#####", "subsubsubsection")
    ])
    docs = splitter.split_text(md_text)

    result_docs = []
    current_title = None
    chunk_idx = 0
    for doc in docs:
        metadata = doc.metadata
        if "title" in metadata:
            current_title = metadata["title"]

        if current_title:
            chunk_idx += 1
            full_title = "" + current_title
            if "section" in metadata:
                full_title += f" / {metadata['section']}"
            if "subsection" in metadata:
                full_title += f" / {metadata['subsection']}"
            if "subsubsection" in metadata:
                full_title += f" / {metadata['subsubsection']}"
            if "subsubsubsection" in metadata:
                full_title += f" / {metadata['subsubsubsection']}"

            content = f"[section_path]: {full_title}\n\n{doc.page_content}"
            doc = Document(page_content=content, metadata={
                **doc.metadata,
                "type": "documentation",
                "source": "dev_center_guide_allmd.md",
                "chunk_idx": chunk_idx
            })

        result_docs.append(doc)

    return result_docs

def load_markdown_file(file_path: str) -> str:
    """마크다운 파일을 로드합니다."""
    with open(file_path, 'r', encoding='utf-8') as file:
        return file.read()

# 마크다운 파일 로드 및 분할
str_md_file = load_markdown_file("../data/dev_center_guide_allmd_touched.md") 
docs_markdown = hierarchical_markdown_split(str_md_file)

print(f"마크다운 문서 분할 완료: {len(docs_markdown)}개 청크")

## 2. Contextual Retrieval Implementation

In [29]:
class ContextualRetrieval:
    def __init__(self, model_name: str = "exaone3.5:latest"):
        self.model_name = model_name
        self.llm = ChatOllama(
            model=model_name,
            temperature=0.1
        )
        
        # Contextual prompt template
        self.contextual_prompt = PromptTemplate.from_template(
            """<document> 
{{WHOLE_DOCUMENT}} 
</document> 
Here is the chunk we want to situate within the whole document 
<chunk> 
{{CHUNK_CONTENT}} 
</chunk> 
Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk. Answer only with the succinct context and nothing else."""
        )
        
        self.chain = self.contextual_prompt | self.llm | StrOutputParser()
    
    def generate_context(self, whole_document: str, chunk_content: str) -> str:
        """Generate contextual information for a chunk within the whole document."""
        try:
            context = self.chain.invoke({
                "WHOLE_DOCUMENT": whole_document,
                "CHUNK_CONTENT": chunk_content
            })
            return context.strip()
        except Exception as e:
            print(f"Error generating context: {e}")
            return ""
    
    def enhance_documents(self, documents: List[Document], whole_document: str) -> List[Document]:
        """Enhance documents with contextual information."""
        enhanced_docs = []
        
        for i, doc in enumerate(documents):
            print(f"Processing document {i+1}/{len(documents)}")
            
            # Generate contextual information
            context = self.generate_context(whole_document, doc.page_content)
            
            # Create enhanced content
            if context:
                enhanced_content = f"[Context]: {context}\n\n{doc.page_content}"
            else:
                enhanced_content = doc.page_content
            
            # Create new document with enhanced content
            enhanced_doc = Document(
                page_content=enhanced_content,
                metadata={
                    **doc.metadata,
                    "original_content": doc.page_content,
                    "contextual_info": context
                }
            )
            
            enhanced_docs.append(enhanced_doc)
        
        return enhanced_docs

## 3. Test Contextual Retrieval on Sample Documents

In [30]:
# Initialize contextual retrieval
contextual_retrieval = ContextualRetrieval(model_name="exaone3.5:latest")

# Test with a small sample first
sample_docs = docs_markdown[:5]  # First 5 documents
print(f"Testing with {len(sample_docs)} sample documents...")

# Enhance documents with contextual information
enhanced_docs = contextual_retrieval.enhance_documents(sample_docs, str_md_file)

print(f"\nEnhanced {len(enhanced_docs)} documents with contextual information")

Testing with 5 sample documents...
Processing document 1/5
Processing document 2/5
Processing document 3/5
Processing document 4/5
Processing document 5/5

Enhanced 5 documents with contextual information


## 4. Compare Original vs Enhanced Documents

In [31]:
# Compare original and enhanced documents
for i, (original, enhanced) in enumerate(zip(sample_docs, enhanced_docs)):
    print(f"\n=== Document {i+1} ===")
    print(f"Original length: {len(original.page_content)}")
    print(f"Enhanced length: {len(enhanced.page_content)}")
    
    # Show the contextual information if available
    if "contextual_info" in enhanced.metadata and enhanced.metadata["contextual_info"]:
        print(f"\nContextual Info: {enhanced.metadata['contextual_info']}")
    
    print(f"\nEnhanced Content Preview:")
    print(enhanced.page_content[:500] + "..." if len(enhanced.page_content) > 500 else enhanced.page_content)
    print("-" * 80)


=== Document 1 ===
Original length: 60
Enhanced length: 141

Contextual Info: Context: Specific application examples in AI technology advancements

Enhanced Content Preview:
[Context]: Context: Specific application examples in AI technology advancements

출처: https://onestore-dev.gitbook.io/dev/tools/billing/v21.md
--------------------------------------------------------------------------------

=== Document 2 ===
Original length: 1221
Enhanced length: 1302

Contextual Info: Context: Specific application examples in AI technology advancements

Enhanced Content Preview:
[Context]: Context: Specific application examples in AI technology advancements

[section_path]: 원스토어 인앱결제 API V7(SDK V21) 연동 안내 및 다운로드

원스토어의 최신 인앱결제 API V7(SDK V21)이 출시되었습니다.  
보다 강력하고 다양한 기능을 지원하는 최신 버전을 적용해보세요.  
{% hint style="info" %}
API V4(SDK V16) 이하 버전과는 호환되지 않습니다. 인앱결제 API V4(SDK V16)에 대한 안내 및 다운로드는 [여기](old-version/v16)를 클릭해주세요.
{% endhint %}  
{% hint style="info" %}
현재 판매중인 앱을 대한민국 외 국가/지역으로 배포하기 위해서는 아래 가이

## 5. Create Vector Database with Enhanced Documents

In [None]:
def embed_and_save(docs: List[Document], output_path: str, model_name: str = "bge-m3:latest"):
    """문서를 임베딩하고 FAISS 데이터베이스로 저장합니다."""
    # 임베딩 모델 초기화
    embedding_model = OllamaEmbeddings(model=model_name)
    
    # FAISS 데이터베이스 생성 및 저장
    db = FAISS.from_documents(docs, embedding_model)
    db.save_local(output_path)
    print(f"✅ 임베딩 저장 완료: {output_path}")

# Create enhanced documents for all documents (or a subset for testing)
print("Creating enhanced documents for all documents...")
all_enhanced_docs = contextual_retrieval.enhance_documents(docs_markdown, str_md_file)

# Save enhanced documents
output_dir = "../models/faiss_contextual_enhanced_" + contextual_retrieval.model_name[:3]
os.makedirs(output_dir, exist_ok=True)
embed_and_save(all_enhanced_docs, output_dir, "bge-m3:latest")

print(f"Total enhanced documents: {len(all_enhanced_docs)}")

## 6. Test Retrieval with Enhanced Documents

In [None]:
# Load the enhanced vector database
embedding_model = OllamaEmbeddings(model="bge-m3:latest")
enhanced_db = FAISS.load_local(
    folder_path=output_dir,
    embeddings=embedding_model,
    allow_dangerous_deserialization=True,
)

enhanced_retriever = enhanced_db.as_retriever(
    search_type="mmr",
    search_kwargs={"k": 10, "fetch_k": 25, "lambda_mult": 0.7}
)

# Test queries
test_queries = [
    "원스토어 인앱결제의 PNS의 개념을 설명해주세요",
    "PNS 메시지 규격의 purchaseState는 어떤 값으로 구성되나요?",
    "원스토어 인앱결제 SDK 사용법",
    "결제 테스트 및 보안 관련 정보"
]

for query in test_queries:
    print(f"\n=== Query: {query} ===")
    results = enhanced_retriever.invoke(query)
    print(f"Retrieved {len(results)} documents")
    
    for i, doc in enumerate(results[:3]):  # Show first 3 results
        print(f"\n--- Result {i+1} ---")
        print(doc.page_content[:300] + "..." if len(doc.page_content) > 300 else doc.page_content)
    print("=" * 80)

## 7. Comparison with Original Retrieval

In [None]:
# Load original vector database for comparison
original_db = FAISS.load_local(
    folder_path="../models/faiss_vs_rag_iap_v10_1_bge",
    embeddings=embedding_model,
    allow_dangerous_deserialization=True,
)

original_retriever = original_db.as_retriever(
    search_type="mmr",
    search_kwargs={"k": 10, "fetch_k": 25, "lambda_mult": 0.7}
)

# Compare retrieval results
query = "원스토어 인앱결제의 PNS의 개념을 설명해주세요"

print("=== Original Retrieval ===")
original_results = original_retriever.invoke(query)
for i, doc in enumerate(original_results[:3]):
    print(f"\nOriginal Result {i+1}:")
    print(doc.page_content[:200] + "...")

print("\n=== Enhanced Retrieval ===")
enhanced_results = enhanced_retriever.invoke(query)
for i, doc in enumerate(enhanced_results[:3]):
    print(f"\nEnhanced Result {i+1}:")
    print(doc.page_content[:200] + "...")

print("\n=== Comparison Summary ===")
print(f"Original results count: {len(original_results)}")
print(f"Enhanced results count: {len(enhanced_results)}")

# Check if contextual information is present in enhanced results
contextual_count = 0
for doc in enhanced_results:
    if "[Context]:" in doc.page_content:
        contextual_count += 1

print(f"Enhanced documents with contextual info: {contextual_count}/{len(enhanced_results)}")

## 8. Performance Analysis

In [None]:
# Analyze the performance of contextual retrieval
import time

# Test retrieval speed
test_query = "원스토어 인앱결제의 PNS의 개념을 설명해주세요"

print("Testing retrieval performance...")

# Test original retrieval
start_time = time.time()
original_results = original_retriever.invoke(test_query)
original_time = time.time() - start_time

# Test enhanced retrieval
start_time = time.time()
enhanced_results = enhanced_retriever.invoke(test_query)
enhanced_time = time.time() - start_time

print(f"\nPerformance Results:")
print(f"Original retrieval time: {original_time:.4f} seconds")
print(f"Enhanced retrieval time: {enhanced_time:.4f} seconds")
print(f"Time difference: {enhanced_time - original_time:.4f} seconds")

# Analyze result quality
print(f"\nQuality Analysis:")
print(f"Original results: {len(original_results)} documents")
print(f"Enhanced results: {len(enhanced_results)} documents")

# Check for contextual information in enhanced results
contextual_docs = [doc for doc in enhanced_results if "[Context]:" in doc.page_content]
print(f"Enhanced documents with context: {len(contextual_docs)}/{len(enhanced_results)}")

if contextual_docs:
    print(f"\nSample contextual information:")
    sample_context = contextual_docs[0].page_content.split("[Context]:")[1].split("\n\n")[0]
    print(sample_context.strip())