# RAPTOR Advanced Usage

This notebook demonstrates advanced RAPTOR features:
1. LiteLLM integration (100+ providers)
2. HuggingFace embeddings
3. Cross-encoder reranking
4. Custom text splitting
5. Tree inspection
6. Framework integrations (LangChain + LlamaIndex)

## Setup

```bash
pip install "raptor-rag[all]"
```

In [None]:
import os
# Set API keys for the providers you use; LiteLLM reads them from environment variables
# os.environ["OPENAI_API_KEY"] = "your-key"
# os.environ["ANTHROPIC_API_KEY"] = "your-key"

## 1. LiteLLM Integration

Use any LLM provider with a single model string. LiteLLM supports OpenAI, Anthropic, Cohere, Mistral, local models, and more.
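A LiteLLM model string is usually written as `provider/model-name`, while well-known model names such as `gpt-4o-mini` can also be given without a prefix. The sketch below only illustrates that naming convention. `split_model_string` and `KNOWN_PROVIDERS` are hypothetical, and treating unprefixed names as OpenAI is a simplifying assumption; LiteLLM's own resolution covers far more providers and edge cases.

```python
# Hypothetical helper: illustrates the "provider/model" naming convention only.
# LiteLLM performs its own, much more complete, provider resolution.
KNOWN_PROVIDERS = {"openai", "anthropic", "cohere", "mistral", "ollama"}  # illustrative subset


def split_model_string(model: str, default_provider: str = "openai") -> tuple[str, str]:
    """Split 'provider/model' into (provider, model); fall back to default_provider."""
    prefix, sep, rest = model.partition("/")
    if sep and prefix in KNOWN_PROVIDERS:
        return prefix, rest
    # No recognised prefix: assume (simplification) the bare name uses the default provider
    return default_provider, model


print(split_model_string("anthropic/claude-3-haiku-20240307"))
# ('anthropic', 'claude-3-haiku-20240307')
print(split_model_string("gpt-4o-mini"))
# ('openai', 'gpt-4o-mini')
```

This is why the config cell below can mix providers: each model is selected purely by its string.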

In [None]:
from raptor import (
    RetrievalAugmentation,
    RetrievalAugmentationConfig,
    LiteLLMSummarizationModel,
    LiteLLMQAModel,
)

# Use Anthropic for summarization, OpenAI for QA
config = RetrievalAugmentationConfig(
    summarization_model=LiteLLMSummarizationModel(
        model="anthropic/claude-3-haiku-20240307",
        system_prompt="You are an expert at creating concise summaries.",
    ),
    qa_model=LiteLLMQAModel(
        model="gpt-4o-mini",
        user_prompt_template="Context: {context}\n\nQuestion: {question}\nProvide a detailed answer:",
    ),
)

# ra = RetrievalAugmentation(config=config)
print("LiteLLM config created successfully")

## 2. HuggingFace Embeddings

Use any HuggingFace sentence-transformer model instead of OpenAI embeddings.
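Whichever model you choose, retrieval works by comparing the query's embedding with each node's embedding. The sketch below shows cosine similarity, a common choice for sentence-transformer vectors. Whether RAPTOR uses exactly this metric is an assumption here, and `cosine_similarity` / `rank_by_similarity` are illustrative helpers, not part of the `raptor` API.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two vectors: 1.0 = same direction, 0.0 = orthogonal."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same dimension")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        raise ValueError("cosine similarity is undefined for zero vectors")
    return dot / (norm_a * norm_b)


def rank_by_similarity(query: Sequence[float], candidates: Sequence[Sequence[float]]) -> list[int]:
    """Return candidate indices ordered from most to least similar to the query."""
    scores = [cosine_similarity(query, c) for c in candidates]
    return sorted(range(len(candidates)), key=lambda i: scores[i], reverse=True)


print(rank_by_similarity([1.0, 0.0], [[0.0, 1.0], [1.0, 1.0], [1.0, 0.0]]))
# [2, 1, 0]
```

With the model from the next cell, you could compare two texts via `cosine_similarity(embedding_model.create_embedding("glass slipper"), embedding_model.create_embedding("shoe"))`; exact values depend on the model.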

In [None]:
from raptor import HuggingFaceEmbeddingModel, RetrievalAugmentationConfig

# BGE-small produces 384-dimensional embeddings and is small enough to run well on CPU
embedding_model = HuggingFaceEmbeddingModel(
    model_name="BAAI/bge-small-en-v1.5",
    device=None,  # Auto-detect GPU/CPU
)

# Test it
emb = embedding_model.create_embedding("Hello world")
print(f"Embedding dimension: {len(emb)}")
print(f"First 5 values: {emb[:5]}")

## 3. Cross-Encoder Reranking

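Add a reranking stage after initial retrieval for higher precision. A cross-encoder reads the query and each candidate passage together, which is slower than comparing precomputed embeddings but usually ranks more accurately, so it is applied only to the first-stage candidates.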
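The retrieve-then-rerank pattern can be sketched without any model: take the first-stage candidates, score every (query, passage) pair with a more expensive scorer, and keep the best few. Here `rerank` and `overlap_scorer` are hypothetical helpers, and `overlap_scorer` is only a toy stand-in for a cross-encoder. With `sentence-transformers` installed, `CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2").predict(pairs)` returns one score per pair in the same way. How `CrossEncoderReRanker` wires this internally is not shown in this notebook, so treat the sketch as illustrative.

```python
from typing import Callable, Sequence

Pair = tuple[str, str]


def rerank(query: str, candidates: Sequence[str],
           score_pairs: Callable[[list[Pair]], Sequence[float]],
           top_n: int) -> list[str]:
    """Score each (query, candidate) pair jointly and return the top_n candidates."""
    pairs = [(query, c) for c in candidates]
    scores = score_pairs(pairs)
    order = sorted(range(len(candidates)), key=lambda i: scores[i], reverse=True)
    return [candidates[i] for i in order[:top_n]]


def overlap_scorer(pairs: list[Pair]) -> list[float]:
    """Toy stand-in for a cross-encoder: counts words shared by query and passage."""
    return [float(len(set(q.lower().split()) & set(p.lower().split()))) for q, p in pairs]


# First-stage candidates (e.g. the tr_top_k results), reranked down to 2
candidates = [
    "the glass slipper fit",
    "the pumpkin became a coach",
    "the prince found the slipper owner",
]
print(rerank("who did the slipper fit", candidates, overlap_scorer, top_n=2))
# ['the glass slipper fit', 'the prince found the slipper owner']
```

Keeping the expensive scorer behind a plain callable makes it easy to swap the toy scorer for a real cross-encoder.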

In [None]:
from raptor import CrossEncoderReRanker, RetrievalAugmentationConfig

# Cross-encoder reranker
reranker = CrossEncoderReRanker(
    model_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
    device=None,
)

config = RetrievalAugmentationConfig(
    tr_reranker=reranker,
    # Retrieve more candidates, then rerank to top results
    tr_top_k=20,
)

print("Reranker config created successfully")

## 4. Custom Text Splitting

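Configure the built-in `DefaultTextSplitter`, or subclass `BaseTextSplitter` and implement `split_text(text, tokenizer, max_tokens)` to define your own strategy.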
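To see what an overlap setting does, here is a minimal sentence-based chunker. Sentences are packed into chunks of up to `max_tokens`, and the last `overlap` sentences of each chunk are repeated at the start of the next one so that context is not cut off at chunk boundaries. Treating `overlap` as a sentence count is an assumption about `DefaultTextSplitter`; `split_with_overlap` is a hypothetical helper, and a whitespace word count stands in for a real tokenizer.

```python
import re
from typing import Callable


def split_with_overlap(text: str, count_tokens: Callable[[str], int],
                       max_tokens: int, overlap: int) -> list[str]:
    """Hypothetical sentence chunker: pack sentences up to max_tokens, carry `overlap` sentences forward."""
    # Naive sentence split on ., ! or ? followed by whitespace
    sentences = [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]
    chunks: list[str] = []
    current: list[str] = []
    current_tokens = 0
    for sentence in sentences:
        n = count_tokens(sentence)
        if current and current_tokens + n > max_tokens:
            chunks.append(" ".join(current))
            # Repeat the last `overlap` sentences at the start of the next chunk
            current = current[-overlap:] if overlap > 0 else []
            current_tokens = sum(count_tokens(s) for s in current)
        current.append(sentence)
        current_tokens += n
    if current:
        chunks.append(" ".join(current))
    return chunks


def word_count(s: str) -> int:
    return len(s.split())


text = "One two. Three four. Five six. Seven eight."
print(split_with_overlap(text, word_count, max_tokens=4, overlap=1))
# ['One two. Three four.', 'Three four. Five six.', 'Five six. Seven eight.']
```

In this simplified version, the carried-over sentences plus the next sentence can exceed `max_tokens`; a production splitter would need to handle that case.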

In [None]:
from raptor import BaseTextSplitter, DefaultTextSplitter

# Default splitter with overlap
overlapping_splitter = DefaultTextSplitter(overlap=2)

# Custom paragraph-based splitter
class ParagraphSplitter(BaseTextSplitter):
    def split_text(self, text, tokenizer, max_tokens):
        paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
        # Greedily pack paragraphs into chunks of up to max_tokens; a single
        # paragraph longer than max_tokens is kept whole as its own chunk
        chunks = []
        current = []
        current_tokens = 0
        for para in paragraphs:
            para_tokens = len(tokenizer.encode(para))
            if current_tokens + para_tokens > max_tokens and current:
                chunks.append("\n\n".join(current))
                current = []
                current_tokens = 0
            current.append(para)
            current_tokens += para_tokens
        if current:
            chunks.append("\n\n".join(current))
        return chunks

print("Custom splitters created successfully")

## 5. Tree Inspection

Examine the tree structure after building.
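Each summary node refers to its children by index, so you can walk from any node down to the leaf chunks it summarizes. The sketch below uses a hypothetical `Node` dataclass with the `index`, `text`, and `children` attributes printed in the next cell, and assumes `children` holds indices into an `all_nodes`-style dict. `leaves_under` is an illustrative helper, and the sketch runs without a built tree.

```python
from dataclasses import dataclass, field


@dataclass
class Node:  # hypothetical stand-in mirroring the attributes inspected below
    index: int
    text: str
    children: set[int] = field(default_factory=set)


def leaves_under(node_index: int, all_nodes: dict[int, Node]) -> list[int]:
    """Collect indices of leaf nodes reachable from node_index (iterative depth-first search)."""
    stack, seen, leaves = [node_index], set(), []
    while stack:
        idx = stack.pop()
        if idx in seen:  # a node may be reachable through more than one parent
            continue
        seen.add(idx)
        node = all_nodes[idx]
        if not node.children:
            leaves.append(idx)
        else:
            stack.extend(node.children)
    return sorted(leaves)


# Toy tree: leaves 0-2, two summaries sharing leaf 1, one root
all_nodes = {
    0: Node(0, "chunk a"), 1: Node(1, "chunk b"), 2: Node(2, "chunk c"),
    3: Node(3, "summary of a, b", {0, 1}), 4: Node(4, "summary of b, c", {1, 2}),
    5: Node(5, "root summary", {3, 4}),
}
print(leaves_under(5, all_nodes))
# [0, 1, 2]
```

If those assumptions hold for the pickled tree, calling `leaves_under(root.index, tree.all_nodes)` for a `root` taken from `tree.root_nodes.values()` would list the leaf chunks behind that summary.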

In [None]:
import pickle
from pathlib import Path

# Load the demo tree (run 01_basic_usage.ipynb first to generate it)
tree_path = Path("../data/cinderella")
if tree_path.exists():
    with open(tree_path, "rb") as f:
        tree = pickle.load(f)
    
    print(f"Number of layers: {tree.num_layers}")
    print(f"Total nodes: {len(tree.all_nodes)}")
    print(f"Leaf nodes: {len(tree.leaf_nodes)}")
    print(f"Root nodes: {len(tree.root_nodes)}")
    
    print("\nNodes per layer:")
    for layer, nodes in tree.layer_to_nodes.items():
        print(f"  Layer {layer}: {len(nodes)} nodes")
    
    # Inspect a leaf node
    leaf = list(tree.leaf_nodes.values())[0]
    print(f"\nSample leaf node (index {leaf.index}):")
    print(f"  Text: {leaf.text[:100]}...")
    print(f"  Children: {leaf.children}")
    print(f"  Embedding models: {list(leaf.embeddings.keys())}")
else:
    print("Tree not found. Run 01_basic_usage.ipynb first to build it.")

## 6. Framework Integrations

### LangChain

```bash
pip install langchain-raptor-rag
```

In [None]:
# LangChain integration example
# from raptor import RetrievalAugmentation
# from langchain_raptor_rag import RaptorRetriever
#
# ra = RetrievalAugmentation()
# ra.add_documents(text)
# retriever = RaptorRetriever(ra=ra, top_k=10, max_tokens=3500)
# docs = retriever.invoke("What happened to Cinderella?")
# for doc in docs:
#     print(f"Layer {doc.metadata['layer_number']}: {doc.page_content[:80]}...")

print("See langchain-raptor-rag package for full integration")

### LlamaIndex

```bash
pip install llama-index-raptor-rag
```

In [None]:
# LlamaIndex integration example
# from raptor import RetrievalAugmentation
# from llama_index_raptor_rag import RaptorRetriever
# from llama_index.core.query_engine import RetrieverQueryEngine
#
# ra = RetrievalAugmentation()
# ra.add_documents(text)
# retriever = RaptorRetriever(ra=ra)
# query_engine = RetrieverQueryEngine.from_args(retriever)
# response = query_engine.query("What happened to Cinderella?")
# print(response)

print("See llama-index-raptor-rag package for full integration")