### RAPTOR: Recursive Abstractive Processing for Tree-Organized Retrieval

RAG systems need to handle both lower-level questions that reference specific facts found in a single document and higher-level questions that distill ideas spanning many documents. Handling both types of questions can be a challenge with typical k-nearest neighbors (k-NN) retrieval over document chunks.

RAPTOR is an effective strategy for this: it creates document summaries that capture higher-level concepts, embeds and clusters those documents, and then summarizes each cluster. Repeating the process recursively yields a tree of summaries at increasing levels of abstraction.

In [1]:
import uuid
from typing import List, Optional, Dict, Any
from dataclasses import dataclass, asdict
from math import ceil

In [4]:
from langchain_community.document_loaders import TextLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_openai import OpenAI, OpenAIEmbeddings, ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_postgres.vectorstores import PGVector

In [6]:
# CONFIG
# Module-level settings shared by the cells below. Note that EMBED_MODEL and
# SUMMARIZER are constructed as soon as this cell runs.

# SQLAlchemy-style URL (psycopg driver) for a Postgres instance on localhost:6024.
# Presumably consumed by PGVector (imported above); its use is not shown in this cell.
# NOTE(review): credentials are hard-coded here; acceptable only for a local demo.
CONNECTION = "postgresql+psycopg://langchain:langchain@localhost:6024/langchain"
# Presumably the PGVector collection name; confirm against the code that builds the store.
COLLECTION = "hp_books"

# Embedding model handle (OpenAI "text-embedding-3-small").
EMBED_MODEL = OpenAIEmbeddings(model="text-embedding-3-small")
# Chat model used by build_reptor_tree to write parent summaries (temperature=0).
SUMMARIZER = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
# Summarization instruction; {text} is replaced with the concatenated text of a
# group of child nodes.
PROMPT_TEXT = "Summarize following text concisely in 1-2 sentences, preserving key entities and facts:\n\n{text}"
PROMPT = ChatPromptTemplate.from_template(PROMPT_TEXT)
# Default for build_reptor_tree's group_size parameter.
GROUP_SIZE = 6 # branching factor: how many child nodes per parent
# TOP_K_ROOT and DESCEND_K are not referenced by the tree builder; presumably
# they are read by the retrieval code in a later cell.
TOP_K_ROOT = 3 # how many top root nodes to consider for initial query routing
DESCEND_K = 2 # how many top children to keep at each descent step

In [7]:
# node dataclass to keep metadata

@dataclass
class Node:
    """One node of the summary tree.

    Level-0 nodes (leaves) hold original text chunks; nodes at higher levels
    hold a summary of the text of their children.
    """

    node_id: str  # unique identifier of this node
    parent_id: Optional[str]  # node_id of the parent; None while unlinked (e.g. the root)
    level: int  # 0 = leaf (original chunk), >0 internal summary level
    text: str  # chunk text for a leaf, summary text for an internal node
    source: Optional[str] = None  # where the text came from, when known
    # Annotated Optional because the default is None rather than a dict.
    extra: Optional[Dict[str, Any]] = None  # optional field for debugging

In [8]:
def build_reptor_tree(doc_path: str, chunk_size=1000, chunk_overlap=100, group_size=GROUP_SIZE):
    """Build a bottom-up summary tree over the text file at ``doc_path``.

    The file is split into chunks, which become the leaves (level 0).
    Consecutive runs of up to ``group_size`` nodes are then summarized by
    ``SUMMARIZER`` into one parent node each, and the step is repeated level
    by level until a level contains a single node.

    Args:
        doc_path: Path of the text file to load (read as UTF-8).
        chunk_size: ``chunk_size`` forwarded to ``RecursiveCharacterTextSplitter``.
        chunk_overlap: ``chunk_overlap`` forwarded to ``RecursiveCharacterTextSplitter``.
        group_size: Number of consecutive children summarized into each
            parent. Must be at least 2, otherwise a level never shrinks.

    Returns:
        A ``(nodes, max_level)`` tuple. ``nodes`` contains every node,
        ordered from the highest level down to the leaves. ``max_level`` is
        the highest level built (0 when the file yields at most one chunk).

    Raises:
        ValueError: If ``group_size`` is smaller than 2.
    """
    # group_size == 1 would create one parent per child forever (an endless
    # loop of LLM calls); group_size == 0 cannot form groups at all.
    if group_size < 2:
        raise ValueError(f"group_size must be at least 2, got {group_size!r}")

    # load the document and split it into chunks
    raw_docs = TextLoader(doc_path, encoding='utf-8').load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    leaf_chunks: List[Document] = splitter.split_documents(raw_docs)

    # create leaf nodes; ids are stored as strings so that leaves and
    # summary nodes share one id type (Node.node_id is declared as str)
    leaves: List[Node] = [
        Node(
            node_id=str(uuid.uuid4()),
            parent_id=None,
            level=0,
            text=chunk.page_content,
            source=getattr(chunk, 'metadata', {}).get('source', doc_path),
        )
        for chunk in leaf_chunks
    ]

    # repeatedly summarize the current level into a smaller parent level
    all_levels = {0: leaves}  # level -> nodes at that level
    current_level = 0
    while len(all_levels[current_level]) > 1:
        children = all_levels[current_level]
        parents: List[Node] = []

        # walk the children in contiguous groups of at most group_size
        for start in range(0, len(children), group_size):
            grouped_children = children[start:start + group_size]
            concat_text = '\n\n'.join(child.text for child in grouped_children)

            # Send real chat messages: ChatPromptTemplate.format() returns a
            # role-prefixed string ("Human: ..."), which would put that
            # prefix inside the prompt text.
            messages = PROMPT.format_messages(text=concat_text)
            resp = SUMMARIZER.invoke(messages)

            # normalize response text (chat models return a message object)
            summary_text = resp.content if hasattr(resp, 'content') else str(resp)
            parent_node = Node(
                node_id=str(uuid.uuid4()),
                parent_id=None,
                level=current_level + 1,
                text=summary_text,
                source=None,
            )

            # link children to this parent
            for child in grouped_children:
                child.parent_id = parent_node.node_id

            parents.append(parent_node)

        current_level += 1
        all_levels[current_level] = parents

    # flatten all levels, highest level first
    max_level = max(all_levels)
    nodes: List[Node] = []
    for lvl in range(max_level, -1, -1):
        nodes.extend(all_levels[lvl])

    return nodes, max_level