# TreeDex: Tree-Based Document RAG Framework

**Vectorless RAG** — index any document into a navigable tree structure, then retrieve relevant sections using any LLM.

Supports: **Gemini, OpenAI, Claude, and Ollama** natively, plus **Groq, Together AI, Fireworks, vLLM, LM Studio, OpenRouter**, or any other OpenAI-compatible endpoint.

---

## 1. Install Dependencies

In [None]:
# pymupdf: PDF parsing (PDFLoader); tiktoken: token counting; google-generativeai: GeminiLLM.
# Optional extras, not installed here: python-docx (DOCXLoader), openai (OpenAILLM),
# anthropic (ClaudeLLM).
!pip install -q pymupdf tiktoken google-generativeai

## 2. Write TreeDex Modules

The following cells write the `treedex/` package directly into the Colab runtime.

In [None]:
import os
# Create the package directory that the %%writefile cells below write into.
os.makedirs("treedex", exist_ok=True)

In [None]:
%%writefile treedex/pdf_parser.py
import fitz  # pymupdf
import tiktoken

# Shared tokenizer. cl100k_base is an OpenAI encoding; counts are used only for
# budgeting page groups and may differ from the target LLM's own tokenizer.
_enc = tiktoken.get_encoding("cl100k_base")


def _count_tokens(text: str) -> int:
    """Return the number of cl100k_base tokens in *text*."""
    token_ids = _enc.encode(text)
    return len(token_ids)


def extract_pages(pdf_path: str) -> list[dict]:
    """Read every page of the PDF at *pdf_path*.

    Returns one dict per page, in document order, with keys ``page_num``
    (0-based position in the document), ``text`` (PyMuPDF plain-text
    extraction) and ``token_count``.
    """
    with fitz.open(pdf_path) as document:
        page_texts = [pg.get_text() for pg in document]
    return [
        {"page_num": number, "text": body, "token_count": _count_tokens(body)}
        for number, body in enumerate(page_texts)
    ]


def pages_to_tagged_text(pages: list[dict], start: int, end: int) -> str:
    """Join pages[start:end+1] into one string with physical index tags.

    Each page's text is wrapped as ``<physical_index_N>...</physical_index_N>``
    where N is the page's ``page_num``; pages are separated by newlines.
    """
    tagged = (
        f"<physical_index_{p['page_num']}>{p['text']}</physical_index_{p['page_num']}>"
        for p in pages[start : end + 1]
    )
    return "\n".join(tagged)


def group_pages(
    pages: list[dict], max_tokens: int = 20000, overlap: int = 1
) -> list[str]:
    """Split pages into token-budget groups, each returned as tagged text.

    Pages are packed greedily in order: a group keeps adding the next page
    while the running ``token_count`` total stays within ``max_tokens``. A
    single page larger than the budget still forms a group on its own.
    The next group starts ``overlap`` pages before the previous group's end,
    so a section that straddles a boundary is seen by both extraction calls.
    Every group still starts at least one page after the previous one, which
    guarantees termination.

    Args:
        pages: Page dicts with ``page_num``, ``text`` and ``token_count``.
        max_tokens: Token budget per group.
        overlap: Number of pages repeated at the start of the next group.

    Returns:
        A list of tagged-text strings (see ``pages_to_tagged_text``). For an
        empty ``pages`` list the result is empty, so callers make no LLM call
        for an empty document.
    """
    if not pages:
        return []

    last = len(pages) - 1
    if sum(p["token_count"] for p in pages) <= max_tokens:
        return [pages_to_tagged_text(pages, 0, last)]

    groups: list[str] = []
    group_start = 0
    while True:
        # The first page of a group is always taken, even if it alone
        # exceeds the budget; otherwise the loop could never advance.
        group_end = group_start
        running = pages[group_start]["token_count"]
        while (
            group_end < last
            and running + pages[group_end + 1]["token_count"] <= max_tokens
        ):
            group_end += 1
            running += pages[group_end]["token_count"]

        groups.append(pages_to_tagged_text(pages, group_start, group_end))
        if group_end >= last:
            return groups

        group_start = max(group_end + 1 - overlap, group_start + 1)


def merge_pdfs(pdf_paths: list[str], output_path: str) -> str:
    """Merge multiple PDF files into one, in the order given.

    Args:
        pdf_paths: Input PDF paths; their pages are appended sequentially.
        output_path: Where the merged PDF is written.

    Returns:
        ``output_path``.
    """
    # Context managers ensure the merged document is closed even if an input
    # fails to open or the save raises (the original leaked it in that case).
    with fitz.open() as merged:
        for path in pdf_paths:
            with fitz.open(path) as doc:
                merged.insert_pdf(doc)
        merged.save(output_path)
    return output_path

In [None]:
%%writefile treedex/tree_builder.py
def list_to_tree(flat_list: list[dict]) -> list[dict]:
    """Convert a flat list with `structure` fields into a hierarchical tree.

    ``structure`` is a dotted path such as ``"1"``, ``"1.2"`` or ``"1.2.3"``.
    A node's parent is the node whose structure is everything before the
    last dot.

    Items are processed in list order, so a parent must appear before its
    children. A child whose parent has not been seen is promoted to a root.
    Input dicts are shallow-copied (with a fresh ``nodes`` list), not mutated.

    The structure value is passed through ``str()`` and ``strip()`` only for
    lookup. This way LLM output such as ``{"structure": 2}`` or ``" 1.1 "``
    nests correctly instead of raising; the stored ``structure`` value is
    left exactly as given.
    """
    nodes_by_structure: dict[str, dict] = {}
    roots: list[dict] = []

    for item in flat_list:
        node = {**item, "nodes": []}
        key = str(node["structure"]).strip()
        nodes_by_structure[key] = node

        parent_key, sep, _ = key.rpartition(".")
        parent = nodes_by_structure.get(parent_key) if sep else None
        if parent is not None:
            parent["nodes"].append(node)
        else:
            roots.append(node)

    return roots


def _assign_ranges(nodes: list[dict], boundary_end: int):
    """Recursively assign start_index and end_index to nodes."""
    for i, node in enumerate(nodes):
        node["start_index"] = node.get("physical_index", 0)

        if i + 1 < len(nodes):
            node["end_index"] = nodes[i + 1].get("physical_index", 0) - 1
        else:
            node["end_index"] = boundary_end

        if node.get("nodes"):
            _assign_ranges(node["nodes"], node["end_index"])


def assign_page_ranges(tree: list[dict], total_pages: int) -> list[dict]:
    """Populate ``start_index``/``end_index`` on every node of *tree* in place.

    Root sections are bounded by the document's last page (``total_pages - 1``).
    The same *tree* object is returned for chaining.
    """
    last_page = total_pages - 1
    _assign_ranges(tree, last_page)
    return tree


def assign_node_ids(tree: list[dict]) -> list[dict]:
    """Number every node in depth-first pre-order as '0001', '0002', ...

    IDs are written in place under ``node_id``; *tree* itself is returned.
    """
    pending = list(reversed(tree))
    next_id = 1
    while pending:
        node = pending.pop()
        node["node_id"] = f"{next_id:04d}"
        next_id += 1
        # Push children reversed so the leftmost child is numbered next.
        pending.extend(reversed(node.get("nodes", [])))
    return tree


def find_large_nodes(
    tree: list[dict],
    max_pages: int = 10,
    max_tokens: int = 20000,
    pages: list[dict] | None = None,
) -> list[dict]:
    """Return nodes that exceed page or token thresholds."""
    large = []

    def _walk(nodes):
        for node in nodes:
            start = node.get("start_index", 0)
            end = node.get("end_index", 0)
            page_count = end - start + 1

            is_large = page_count > max_pages

            if not is_large and pages is not None:
                token_sum = sum(
                    p["token_count"]
                    for p in pages
                    if start <= p["page_num"] <= end
                )
                is_large = token_sum > max_tokens

            if is_large:
                large.append(node)

            _walk(node.get("nodes", []))

    _walk(tree)
    return large


def embed_text_in_tree(tree: list[dict], pages: list[dict]) -> list[dict]:
    """Set each node's ``text`` to the newline-joined text of its page range.

    A page belongs to a node when ``start_index <= page_num <= end_index``.
    Pages are joined in the order they appear in *pages*. The tree is
    modified in place and returned.
    """
    stack = list(tree)
    while stack:
        node = stack.pop()
        lo = node.get("start_index", 0)
        hi = node.get("end_index", 0)
        node["text"] = "\n".join(
            pg["text"] for pg in pages if lo <= pg["page_num"] <= hi
        )
        stack.extend(node.get("nodes", []))
    return tree

In [None]:
%%writefile treedex/tree_utils.py
import copy
import json
import re


def create_node_mapping(tree: list[dict]) -> dict:
    """Index every node that has a ``node_id`` as ``{node_id: node}``.

    Nodes are visited depth-first in pre-order, so dict insertion order
    follows the tree. The values are the original node objects, not copies.
    """
    mapping: dict = {}
    stack = list(reversed(tree))
    while stack:
        node = stack.pop()
        if "node_id" in node:
            mapping[node["node_id"]] = node
        stack.extend(reversed(node.get("nodes", [])))
    return mapping


def strip_text_from_tree(tree: list[dict]) -> list[dict]:
    """Return a deep copy of *tree* with every ``text`` key removed.

    The input tree is left untouched.
    """
    clone = copy.deepcopy(tree)
    stack = list(clone)
    while stack:
        node = stack.pop()
        node.pop("text", None)
        stack.extend(node.get("nodes", []))
    return clone


def collect_node_texts(node_ids: list[str], node_map: dict) -> str:
    """Concatenate the text of the given nodes, each under a bracketed header.

    Headers are ``[structure: title]``, or ``[title]`` when the node has no
    structure. IDs missing from *node_map* are skipped. Sections are
    separated by a blank line.
    """

    def _render(node: dict) -> str:
        title = node.get("title", "Untitled")
        structure = node.get("structure", "")
        label = f"{structure}: {title}" if structure else title
        return f"[{label}]\n{node.get('text', '')}"

    found = (node_map.get(nid) for nid in node_ids)
    return "\n\n".join(_render(n) for n in found if n is not None)


def count_nodes(tree: list[dict]) -> int:
    """Return the total number of nodes in *tree*, at every depth."""
    return sum(1 + count_nodes(child.get("nodes", [])) for child in tree)


def get_leaf_nodes(tree: list[dict]) -> list[dict]:
    """Return nodes whose ``nodes`` list is empty or missing, left to right."""
    leaves: list[dict] = []
    stack = list(reversed(tree))
    while stack:
        node = stack.pop()
        kids = node.get("nodes", [])
        if kids:
            stack.extend(reversed(kids))
        else:
            leaves.append(node)
    return leaves


def tree_to_flat_list(tree: list[dict]) -> list[dict]:
    """Return the tree's nodes in depth-first pre-order, without ``nodes`` keys.

    Each entry is a new shallow dict; the input tree is not modified.
    """
    flat: list[dict] = []
    stack = list(reversed(tree))
    while stack:
        node = stack.pop()
        flat.append({key: value for key, value in node.items() if key != "nodes"})
        stack.extend(reversed(node.get("nodes", [])))
    return flat


_TRAILING_COMMA_RE = re.compile(r",\s*([}\]])")
_JSON_OPENER_RE = re.compile(r"[\[{]")


def _loads_lenient(candidate: str):
    """json.loads, retrying once with trailing commas before } or ] removed."""
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        return json.loads(_TRAILING_COMMA_RE.sub(r"\1", candidate))


def _balanced_end(text: str, start: int):
    """Return the index just past the bracketed value opening at ``text[start]``.

    ``{}`` and ``[]`` are counted together. Brackets inside JSON string
    literals (including escaped quotes) are ignored. Returns None if the
    brackets never balance.
    """
    depth = 0
    in_string = False
    escaped = False
    for i in range(start, len(text)):
        ch = text[i]
        if in_string:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch in "{[":
            depth += 1
        elif ch in "}]":
            depth -= 1
            if depth == 0:
                return i + 1
    return None


def _parse_json_at(decoder: json.JSONDecoder, text: str, start: int):
    """Try to parse a JSON value starting at ``text[start]``.

    Returns ``(value, end_index)`` or None. A strict parse is tried first,
    then the bracket-balanced span with trailing commas removed.
    """
    try:
        value, consumed = decoder.raw_decode(text[start:])
        return value, start + consumed
    except json.JSONDecodeError:
        pass
    end = _balanced_end(text, start)
    if end is None:
        return None
    try:
        return _loads_lenient(text[start:end]), end
    except json.JSONDecodeError:
        return None


def extract_json(text: str):
    """Robust JSON extraction from LLM responses.

    Tries, in order:

    1. the whole text as JSON;
    2. the first fenced code block (```json ... ``` or ``` ... ```), also
       with trailing commas removed;
    3. every ``{``/``[`` in order of appearance, returning the longest JSON
       value found (earliest wins ties).

    Step 3 scans by position instead of trying every ``{`` before any ``[``.
    That way "Here is the list: [{...}, {...}]" yields the whole list rather
    than its first object. Preferring the longest value keeps prose such as
    "see section [2]" from shadowing the real payload. Brackets inside string
    values no longer break the balance count.

    Raises:
        ValueError: if no JSON value can be recovered.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    match = re.search(r"```(?:json)?\s*\n?(.*?)```", text, re.DOTALL)
    if match:
        try:
            return _loads_lenient(match.group(1).strip())
        except json.JSONDecodeError:
            pass

    decoder = json.JSONDecoder()
    best = None  # (span_length, value)
    opener = _JSON_OPENER_RE.search(text)
    while opener is not None:
        start = opener.start()
        parsed = _parse_json_at(decoder, text, start)
        if parsed is None:
            next_pos = start + 1
        else:
            value, end = parsed
            if best is None or end - start > best[0]:
                best = (end - start, value)
            # Skip past the parsed value so nested brackets are not re-tried.
            next_pos = end
        opener = _JSON_OPENER_RE.search(text, next_pos)

    if best is not None:
        return best[1]

    raise ValueError(f"Could not extract JSON from text: {text[:200]}...")


def print_tree(tree: list[dict], indent: int = 0):
    """Print one line per node, indenting two spaces per depth level.

    Format: ``[node_id] structure: title (pages start-end)``, with ``????``,
    ``Untitled`` and ``?`` substituted for missing fields.
    """
    pad = " " * (2 * indent)
    for node in tree:
        label = "[{}] {}: {} (pages {}-{})".format(
            node.get("node_id", "????"),
            node.get("structure", ""),
            node.get("title", "Untitled"),
            node.get("start_index", "?"),
            node.get("end_index", "?"),
        )
        print(pad + label)
        print_tree(node.get("nodes", []), indent + 1)

In [None]:
%%writefile treedex/loaders.py
"""Document loaders for multiple file formats."""

import os
import re
from html.parser import HTMLParser

import tiktoken

# Tokenizer used to size synthetic pages. cl100k_base is an OpenAI encoding;
# counts are approximate for other model families.
_enc = tiktoken.get_encoding("cl100k_base")


def _count_tokens(text: str) -> int:
    """Return the number of cl100k_base tokens in *text*."""
    token_ids = _enc.encode(text)
    return len(token_ids)


def _text_to_pages(text: str, chars_per_page: int = 3000) -> list[dict]:
    """Split plain text into synthetic pages by character count."""
    pages = []
    for i in range(0, len(text), chars_per_page):
        chunk = text[i : i + chars_per_page]
        pages.append({
            "page_num": len(pages),
            "text": chunk,
            "token_count": _count_tokens(chunk),
        })
    return pages


class PDFLoader:
    """Loader for ``.pdf`` files, backed by PyMuPDF through treedex.pdf_parser."""

    def load(self, path: str) -> list[dict]:
        """Return one page dict per PDF page (see ``extract_pages``)."""
        # Deferred import: PyMuPDF is only pulled in when a PDF is actually loaded.
        from treedex.pdf_parser import extract_pages as _extract

        pages = _extract(path)
        return pages


class TextLoader:
    """Loader for plain-text and Markdown files, split into fixed-size pages."""

    def __init__(self, chars_per_page: int = 3000):
        # Size of each synthetic page, in characters.
        self.chars_per_page = chars_per_page

    def load(self, path: str) -> list[dict]:
        """Read *path* as UTF-8 and chunk it into synthetic pages."""
        with open(path, encoding="utf-8") as handle:
            content = handle.read()
        return _text_to_pages(content, self.chars_per_page)


class _HTMLStripper(HTMLParser):
    """Simple HTML-to-text converter using stdlib."""

    def __init__(self):
        super().__init__()
        self._parts: list[str] = []
        self._skip = False

    def handle_starttag(self, tag, attrs):
        if tag in ("script", "style"):
            self._skip = True

    def handle_endtag(self, tag):
        if tag in ("script", "style"):
            self._skip = False
        if tag in ("p", "div", "br", "h1", "h2", "h3", "h4", "h5", "h6", "li", "tr"):
            self._parts.append("\n")

    def handle_data(self, data):
        if not self._skip:
            self._parts.append(data)

    def get_text(self) -> str:
        raw = "".join(self._parts)
        return re.sub(r"\n{3,}", "\n\n", raw).strip()


class HTMLLoader:
    """Load HTML files, stripping tags to plain text (stdlib only)."""

    def __init__(self, chars_per_page: int = 3000):
        # Characters per synthetic page produced by _text_to_pages.
        self.chars_per_page = chars_per_page

    def load(self, path: str) -> list[dict]:
        """Read *path* as UTF-8 HTML and return its visible text as pages."""
        with open(path, "r", encoding="utf-8") as f:
            html = f.read()
        stripper = _HTMLStripper()
        stripper.feed(html)
        # close() flushes whatever HTMLParser is still buffering after feed().
        # Examples: trailing text whose last '&' might start an entity, or an
        # unterminated final construct. Without it, that content is dropped.
        stripper.close()
        text = stripper.get_text()
        return _text_to_pages(text, self.chars_per_page)


class DOCXLoader:
    """Loader for ``.docx`` files via the optional ``python-docx`` package.

    Only paragraph text is read; paragraphs are joined with newlines and
    chunked into synthetic pages.
    """

    def __init__(self, chars_per_page: int = 3000):
        # Size of each synthetic page, in characters.
        self.chars_per_page = chars_per_page

    def load(self, path: str) -> list[dict]:
        """Extract paragraph text from *path* and split it into pages."""
        # Deferred import so python-docx is only required for .docx input.
        import docx

        document = docx.Document(path)
        lines = [para.text for para in document.paragraphs]
        return _text_to_pages("\n".join(lines), self.chars_per_page)


# Lower-cased file extension (including the dot) -> loader class.
# auto_loader instantiates the class with its default arguments; the key order
# here is also the order shown in its "Supported: ..." error message.
# .docx needs the optional python-docx package (imported inside DOCXLoader.load).
_EXTENSION_MAP = {
    ".pdf": PDFLoader,
    ".txt": TextLoader,
    ".md": TextLoader,
    ".html": HTMLLoader,
    ".htm": HTMLLoader,
    ".docx": DOCXLoader,
}


def auto_loader(path: str) -> list[dict]:
    """Pick a loader from *path*'s extension (case-insensitive) and load it.

    Raises:
        ValueError: if the extension is not in ``_EXTENSION_MAP``.
    """
    _, extension = os.path.splitext(path)
    extension = extension.lower()
    if (loader_cls := _EXTENSION_MAP.get(extension)) is None:
        supported = ", ".join(_EXTENSION_MAP)
        raise ValueError(
            f"Unsupported file extension '{extension}'. "
            f"Supported: {supported}"
        )
    return loader_cls().load(path)

In [None]:
%%writefile treedex/llm_backends.py
"""LLM backends for TreeDex.

Named providers (Gemini, OpenAI, Claude) lazy-import their SDKs.
OpenAICompatibleLLM and OllamaLLM use only stdlib (urllib).
"""

import json
import urllib.request
import urllib.error
from abc import ABC, abstractmethod


class BaseLLM(ABC):
    """Abstract interface every TreeDex LLM backend implements.

    Subclasses only have to provide ``generate``; the default ``repr`` is
    just the class name followed by ``()``.
    """

    @abstractmethod
    def generate(self, prompt: str) -> str:
        """Return the model's text completion for *prompt*."""

    def __repr__(self):
        return type(self).__name__ + "()"


class GeminiLLM(BaseLLM):
    """Google Gemini backend using the ``google-generativeai`` SDK.

    The SDK is imported and configured on first use, so it is only required
    when this backend actually generates.
    """

    def __init__(self, api_key: str, model: str = "gemini-2.0-flash"):
        self.api_key = api_key
        self.model_name = model
        self._client = None  # GenerativeModel, created lazily by _get_client()

    def _get_client(self):
        if self._client is not None:
            return self._client
        import google.generativeai as genai

        # NOTE(review): genai.configure is a module-level call; confirm how
        # several GeminiLLM instances with different keys interact.
        genai.configure(api_key=self.api_key)
        self._client = genai.GenerativeModel(self.model_name)
        return self._client

    def generate(self, prompt: str) -> str:
        reply = self._get_client().generate_content(prompt)
        return reply.text

    def __repr__(self):
        return f"GeminiLLM(model={self.model_name!r})"


class OpenAILLM(BaseLLM):
    """OpenAI Chat Completions backend using the official ``openai`` SDK.

    The SDK client is built on first use, so ``openai`` is only required
    when this backend actually generates.
    """

    def __init__(self, api_key: str, model: str = "gpt-4o"):
        self.api_key = api_key
        self.model_name = model
        self._client = None  # openai.OpenAI instance, created lazily

    def _get_client(self):
        if self._client is not None:
            return self._client
        import openai

        self._client = openai.OpenAI(api_key=self.api_key)
        return self._client

    def generate(self, prompt: str) -> str:
        conversation = [{"role": "user", "content": prompt}]
        completion = self._get_client().chat.completions.create(
            model=self.model_name,
            messages=conversation,
        )
        first_choice = completion.choices[0]
        return first_choice.message.content

    def __repr__(self):
        return f"OpenAILLM(model={self.model_name!r})"


class ClaudeLLM(BaseLLM):
    """Anthropic Claude via anthropic SDK.

    Args:
        api_key: Anthropic API key.
        model: Model name passed to ``messages.create``.
        max_tokens: Output-token cap per call (a required Messages API
            parameter). Structure extraction over a large page group can
            produce long JSON; raise this if replies come back cut off.
            Defaults to the previously hard-coded 4096.
    """

    def __init__(self, api_key: str, model: str = "claude-sonnet-4-20250514",
                 max_tokens: int = 4096):
        self.api_key = api_key
        self.model_name = model
        self.max_tokens = max_tokens
        self._client = None  # anthropic.Anthropic instance, created lazily

    def _get_client(self):
        if self._client is None:
            import anthropic

            self._client = anthropic.Anthropic(api_key=self.api_key)
        return self._client

    def generate(self, prompt: str) -> str:
        client = self._get_client()
        response = client.messages.create(
            model=self.model_name,
            max_tokens=self.max_tokens,
            messages=[{"role": "user", "content": prompt}],
        )
        # response.content is a list of content blocks; join the text blocks
        # instead of assuming the first block is text.
        return "".join(
            block.text
            for block in response.content
            if getattr(block, "type", None) == "text"
        )

    def __repr__(self):
        return f"ClaudeLLM(model={self.model_name!r})"


class OpenAICompatibleLLM(BaseLLM):
    """Universal backend for any OpenAI-compatible API endpoint.

    Works with: Groq, Together AI, Fireworks, vLLM, LM Studio,
    OpenRouter, Ollama (OpenAI mode), and any other compatible service.

    Uses only stdlib (urllib) — zero SDK dependencies.

    Args:
        base_url: API root, e.g. ``https://api.groq.com/openai/v1``. A
            trailing slash is removed and ``/chat/completions`` is appended
            per request.
        model: Model name sent in the request body.
        api_key: Sent as a ``Bearer`` token when given; omit for servers
            that need no auth.
        max_tokens: Upper bound on generated tokens per request.
        temperature: Sampling temperature (default 0.0).
        timeout: Seconds to wait for the HTTP response. Large prompts on
            slow or local models can need more than the 120 s default.
    """

    def __init__(
        self,
        base_url: str,
        model: str,
        api_key: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.0,
        timeout: float = 120,
    ):
        self.base_url = base_url.rstrip("/")
        self.model = model
        self.api_key = api_key
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.timeout = timeout

    def generate(self, prompt: str) -> str:
        """POST a single-message chat completion and return the reply text.

        Raises:
            RuntimeError: on an HTTP error status, or if the response body
                has no ``choices[0].message.content``.
        """
        url = f"{self.base_url}/chat/completions"

        payload = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
        }

        data = json.dumps(payload).encode("utf-8")

        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        req = urllib.request.Request(url, data=data, headers=headers, method="POST")

        try:
            with urllib.request.urlopen(req, timeout=self.timeout) as resp:
                body = json.loads(resp.read().decode("utf-8"))
        except urllib.error.HTTPError as e:
            error_body = e.read().decode("utf-8", errors="replace")
            raise RuntimeError(
                f"API request failed ({e.code}): {error_body}"
            ) from e

        try:
            content = body["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError) as e:
            # Surface the unexpected payload (e.g. an error object returned
            # with a 200 status) instead of a bare KeyError.
            raise RuntimeError(
                f"Unexpected API response: {json.dumps(body)[:500]}"
            ) from e
        # message.content can be null (e.g. refusals); keep the str contract.
        return content if content is not None else ""

    def __repr__(self):
        return f"OpenAICompatibleLLM(base_url={self.base_url!r}, model={self.model!r})"


class OllamaLLM(BaseLLM):
    """Ollama native backend using /api/generate endpoint.

    Uses only stdlib (urllib) — zero SDK dependencies.

    Args:
        model: Ollama model name.
        base_url: Server root; a trailing slash is removed.
        timeout: Seconds to wait for the (non-streamed) response. Local
            models can need more than the 120 s default for large prompts.
        options: Optional model parameters forwarded as the request's
            ``options`` field, e.g. ``{"num_ctx": 32768}``.
            NOTE(review): the server's default context window may be smaller
            than TreeDex's 20k-token page groups; set ``num_ctx`` if prompts
            appear truncated.
    """

    def __init__(
        self,
        model: str = "llama3",
        base_url: str = "http://localhost:11434",
        timeout: float = 120,
        options: dict | None = None,
    ):
        self.model = model
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        # Copied so later mutation of the caller's dict does not leak in.
        self.options = dict(options) if options else None

    def generate(self, prompt: str) -> str:
        """Send *prompt* to /api/generate (stream disabled) and return the text.

        Raises:
            RuntimeError: on an HTTP error status, or if the response body
                has no ``response`` field.
        """
        url = f"{self.base_url}/api/generate"

        payload = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
        }
        if self.options:
            payload["options"] = self.options

        data = json.dumps(payload).encode("utf-8")
        headers = {"Content-Type": "application/json"}

        req = urllib.request.Request(url, data=data, headers=headers, method="POST")

        try:
            with urllib.request.urlopen(req, timeout=self.timeout) as resp:
                body = json.loads(resp.read().decode("utf-8"))
        except urllib.error.HTTPError as e:
            error_body = e.read().decode("utf-8", errors="replace")
            raise RuntimeError(
                f"Ollama request failed ({e.code}): {error_body}"
            ) from e

        try:
            return body["response"]
        except (KeyError, TypeError) as e:
            raise RuntimeError(
                f"Unexpected Ollama response: {json.dumps(body)[:500]}"
            ) from e

    def __repr__(self):
        return f"OllamaLLM(model={self.model!r})"

In [None]:
%%writefile treedex/prompts.py
"""Prompt templates for structure extraction and retrieval."""

# All templates below are filled with str.format, so any literal braces added
# to them must be doubled ({{ }}); substituted values are not re-parsed.

# Sent for the first page group. {text} is tagged page text in which each
# page is wrapped in <physical_index_N>...</physical_index_N>. The reply is
# parsed with extract_json and expected to be a JSON list of
# {"structure", "title", "physical_index"} objects.
STRUCTURE_EXTRACTION_PROMPT = """\
You are a document structure analyzer. Given the following document text with \
physical page index tags, extract the hierarchical structure (table of contents).

Return a JSON list of objects, each with:
- "structure": hierarchical numbering like "1", "1.1", "1.2.3"
- "title": the section/chapter title
- "physical_index": the page number (from the <physical_index_N> tag) where this section starts

Rules:
- Use the physical_index tags to determine page numbers
- Create a logical hierarchy: chapters -> sections -> subsections
- Every section must have a unique structure ID
- Return ONLY valid JSON — no extra text

Document text:
{text}

JSON output:
"""

# Sent for every page group after the first. {previous_structure} is the JSON
# of all sections extracted so far; {text} is the next tagged page group.
# NOTE(review): each generate() call is a fresh single-message request, so
# "same format as before" relies on the model inferring the format from
# {previous_structure}.
STRUCTURE_CONTINUE_PROMPT = """\
You are continuing to extract the hierarchical structure of a document.

Here is the structure extracted so far:
{previous_structure}

Now extract the structure from the next portion of the document. \
Continue the numbering from where the previous structure left off. \
If a section from the previous portion continues into this portion, \
do NOT duplicate it.

Return a JSON list of NEW sections only (same format as before).

Document text:
{text}

JSON output:
"""

# Sent by TreeDex.query. {tree_structure} is the index tree as JSON with the
# per-node text removed; {query} is the user's question. The reply is
# expected to be a JSON object with "node_ids" and "reasoning".
RETRIEVAL_PROMPT = """\
You are a document retrieval system. Given a document's tree structure and a \
user query, select the most relevant sections that would contain the answer.

Document structure:
{tree_structure}

User query: {query}

Return a JSON object with:
- "node_ids": list of node IDs (strings like "0001", "0005") that are most \
relevant to the query
- "reasoning": brief explanation of why these sections were selected

Select the smallest set of sections that fully covers the answer. \
Prefer leaf nodes over parent nodes when the leaf contains the specific content. \
Return ONLY valid JSON.

JSON output:
"""

In [None]:
%%writefile treedex/core.py
"""TreeDex: Tree-based document RAG framework."""

import json
import os

from treedex.loaders import auto_loader, PDFLoader
from treedex.pdf_parser import group_pages
from treedex.tree_builder import (
    assign_node_ids,
    assign_page_ranges,
    embed_text_in_tree,
    find_large_nodes,
    list_to_tree,
)
from treedex.tree_utils import (
    collect_node_texts,
    count_nodes,
    create_node_mapping,
    extract_json,
    get_leaf_nodes,
    print_tree,
    strip_text_from_tree,
)
from treedex.prompts import (
    STRUCTURE_EXTRACTION_PROMPT,
    STRUCTURE_CONTINUE_PROMPT,
    RETRIEVAL_PROMPT,
)


class QueryResult:
    """Outcome of :meth:`TreeDex.query`.

    Attributes:
        context: Concatenated text of the selected nodes.
        node_ids: IDs of the selected nodes.
        page_ranges: ``(start, end)`` 0-based inclusive page ranges, one per
            resolved node.
        reasoning: The LLM's explanation for its selection.
    """

    def __init__(self, context: str, node_ids: list[str],
                 page_ranges: list, reasoning: str):
        self.context, self.node_ids = context, node_ids
        self.page_ranges, self.reasoning = page_ranges, reasoning

    @property
    def pages_str(self) -> str:
        """Human-readable, 1-based page ranges such as 'pages 5-8, 12'."""
        if not self.page_ranges:
            return "no pages"
        labels = [
            str(lo + 1) if lo == hi else f"{lo + 1}-{hi + 1}"
            for lo, hi in self.page_ranges
        ]
        return "pages " + ", ".join(labels)

    def __repr__(self):
        summary = f"nodes={self.node_ids}, {self.pages_str}"
        return f"QueryResult({summary}, context_len={len(self.context)})"


class TreeDex:
    """Tree-based document index for RAG retrieval.

    Holds the section tree (nested dicts with children under ``nodes``), the
    page dicts it was built from, and an optional default LLM for
    :meth:`query`. At query time the LLM reads the text-free tree and picks
    node IDs, and the text of those nodes is returned as context.
    """

    def __init__(self, tree: list[dict], pages: list[dict],
                 llm=None):
        self.tree = tree
        self.pages = pages
        self.llm = llm
        # node_id -> node dict, used to resolve the IDs the LLM selects.
        self._node_map = create_node_mapping(tree)

    @classmethod
    def from_file(cls, path: str, llm, loader=None,
                  max_tokens: int = 20000, overlap: int = 1,
                  verbose: bool = True):
        """Build a TreeDex index from a file.

        Args:
            path: Document to index. Unless ``loader`` is given, the loader is
                chosen from the file extension by ``auto_loader``.
            llm: Backend used for structure extraction; also stored as the
                default for :meth:`query`.
            loader: Optional object with ``load(path) -> list[dict]``.
            max_tokens: Token budget per structure-extraction call.
            overlap: Pages repeated between consecutive page groups.
            verbose: Print progress messages.
        """
        if verbose:
            print(f"Loading: {os.path.basename(path)}")

        if loader is not None:
            pages = loader.load(path)
        else:
            pages = auto_loader(path)

        if verbose:
            total_tokens = sum(p["token_count"] for p in pages)
            print(f"  {len(pages)} pages, {total_tokens:,} tokens")

        return cls.from_pages(pages, llm, max_tokens=max_tokens,
                              overlap=overlap, verbose=verbose)

    @classmethod
    def from_pages(cls, pages: list[dict], llm,
                   max_tokens: int = 20000, overlap: int = 1,
                   verbose: bool = True):
        """Build a TreeDex index from pre-extracted pages.

        The first page group is sent with the extraction prompt. Each later
        group is sent with the continuation prompt plus all sections found so
        far. The combined section list is then turned into a tree with page
        ranges, node IDs and per-node text.
        """
        groups = group_pages(pages, max_tokens=max_tokens, overlap=overlap)

        if verbose:
            print(f"  {len(groups)} page group(s) for structure extraction")

        all_sections = []
        for i, group_text in enumerate(groups):
            if verbose:
                print(f"  Extracting structure from group {i + 1}/{len(groups)}...")

            if i == 0:
                prompt = STRUCTURE_EXTRACTION_PROMPT.format(text=group_text)
            else:
                # ensure_ascii=False keeps non-English titles readable to the LLM.
                prev_json = json.dumps(all_sections, indent=2, ensure_ascii=False)
                prompt = STRUCTURE_CONTINUE_PROMPT.format(
                    previous_structure=prev_json, text=group_text
                )

            response = llm.generate(prompt)
            sections = extract_json(response)

            if isinstance(sections, dict):
                # Some models wrap the list as {"sections": [...]}.
                sections = sections.get("sections")
            if isinstance(sections, list):
                all_sections.extend(cls._usable_sections(sections))

        if verbose:
            print(f"  Extracted {len(all_sections)} sections")

        tree = list_to_tree(all_sections)
        assign_page_ranges(tree, total_pages=len(pages))
        assign_node_ids(tree)
        embed_text_in_tree(tree, pages)

        if verbose:
            print(f"  Tree: {count_nodes(tree)} nodes")

        return cls(tree, pages, llm)

    @classmethod
    def from_tree(cls, tree: list[dict], pages: list[dict], llm=None):
        """Create a TreeDex from an existing tree and pages."""
        return cls(tree, pages, llm)

    def query(self, question: str, llm=None) -> "QueryResult":
        """Query the index and return relevant context.

        Args:
            question: Natural-language question.
            llm: Optional backend overriding the index's default LLM.

        Returns:
            A QueryResult with:

            - the selected node IDs, normalized to the index's ID form and
              de-duplicated in the LLM's order;
            - their concatenated text;
            - their 0-based page ranges;
            - the LLM's reasoning.

        Raises:
            ValueError: if no LLM is available, or the LLM reply contains no
                parseable JSON.
        """
        active_llm = llm or self.llm
        if active_llm is None:
            raise ValueError("No LLM provided. Pass llm= to query() or TreeDex constructor.")

        stripped = strip_text_from_tree(self.tree)
        tree_json = json.dumps(stripped, indent=2, ensure_ascii=False)

        prompt = RETRIEVAL_PROMPT.format(
            tree_structure=tree_json, query=question
        )

        response = active_llm.generate(prompt)
        raw_ids, reasoning = self._parse_retrieval_response(extract_json(response))
        # dict.fromkeys de-duplicates while preserving the LLM's ordering.
        node_ids = list(dict.fromkeys(self._normalize_node_id(n) for n in raw_ids))

        context = collect_node_texts(node_ids, self._node_map)

        page_ranges = []
        for nid in node_ids:
            node = self._node_map.get(nid)
            if node:
                start = node.get("start_index", 0)
                end = node.get("end_index", 0)
                page_ranges.append((start, end))

        return QueryResult(
            context=context,
            node_ids=node_ids,
            page_ranges=page_ranges,
            reasoning=reasoning,
        )

    @staticmethod
    def _usable_sections(sections: list) -> list[dict]:
        """Keep only LLM-returned entries that list_to_tree can place.

        Entries that are not dicts, or that have no ``structure`` value,
        would otherwise crash tree construction after every extraction call
        has already been made.
        """
        return [
            s for s in sections
            if isinstance(s, dict) and s.get("structure") is not None
        ]

    @staticmethod
    def _parse_retrieval_response(result) -> tuple[list, str]:
        """Pull ``(node_ids, reasoning)`` out of the parsed retrieval reply.

        Accepts the expected object form and a bare list of IDs. A single
        string or int ID is wrapped in a list. Anything else yields no IDs.
        """
        if isinstance(result, dict):
            raw_ids = result.get("node_ids", [])
            reasoning = result.get("reasoning", "")
        elif isinstance(result, list):
            raw_ids, reasoning = result, ""
        else:
            return [], ""

        if isinstance(raw_ids, (str, int)):
            raw_ids = [raw_ids]
        elif not isinstance(raw_ids, (list, tuple)):
            raw_ids = []
        return list(raw_ids), reasoning

    def _normalize_node_id(self, nid) -> str:
        """Map an LLM-supplied ID onto this index's IDs where possible.

        IDs are compared as stripped strings. A purely numeric ID that is not
        found as-is (e.g. ``5`` or ``"5"``) is retried zero-padded to four
        digits, the form assign_node_ids produces.
        """
        key = str(nid).strip()
        if key not in self._node_map and key.isdigit():
            padded = f"{int(key):04d}"
            if padded in self._node_map:
                return padded
        return key

    def save(self, path: str) -> str:
        """Save the index to a JSON file.

        Node text is stripped from the tree; it is rebuilt from the saved
        pages by :meth:`load`.
        """
        stripped = strip_text_from_tree(self.tree)

        data = {
            "version": "1.0",
            "framework": "TreeDex",
            "tree": stripped,
            "pages": self.pages,
        }

        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)

        return path

    @classmethod
    def load(cls, path: str, llm=None):
        """Load a TreeDex index from a JSON file written by :meth:`save`."""
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        tree = data["tree"]
        pages = data["pages"]

        assign_page_ranges(tree, total_pages=len(pages))
        embed_text_in_tree(tree, pages)

        return cls(tree, pages, llm)

    def show_tree(self):
        """Pretty-print the tree structure."""
        print_tree(self.tree)

    def stats(self) -> dict:
        """Return index statistics."""
        total_tokens = sum(p["token_count"] for p in self.pages)
        leaves = get_leaf_nodes(self.tree)
        return {
            "total_pages": len(self.pages),
            "total_tokens": total_tokens,
            "total_nodes": count_nodes(self.tree),
            "leaf_nodes": len(leaves),
            "root_sections": len(self.tree),
        }

    def find_large_sections(self, max_pages: int = 10,
                            max_tokens: int = 20000) -> list[dict]:
        """Find sections that exceed size thresholds."""
        return find_large_nodes(
            self.tree, max_pages=max_pages,
            max_tokens=max_tokens, pages=self.pages
        )

In [None]:
%%writefile treedex/__init__.py
"""TreeDex: Tree-based document RAG framework."""

from treedex.core import TreeDex, QueryResult
from treedex.loaders import PDFLoader, TextLoader, HTMLLoader, DOCXLoader, auto_loader
from treedex.llm_backends import (
    GeminiLLM,
    OpenAILLM,
    ClaudeLLM,
    OllamaLLM,
    OpenAICompatibleLLM,
)

# Package version. This is separate from the "version" field that
# TreeDex.save writes into index files.
__version__ = "0.1.0"

# Public API re-exported at package level (also governs `from treedex import *`).
__all__ = [
    "TreeDex",
    "QueryResult",
    "PDFLoader",
    "TextLoader",
    "HTMLLoader",
    "DOCXLoader",
    "auto_loader",
    "GeminiLLM",
    "OpenAILLM",
    "ClaudeLLM",
    "OllamaLLM",
    "OpenAICompatibleLLM",
]

---
## 3. Setup LLM

Choose your LLM backend. The examples below cover **Gemini** (default), **OpenAI-compatible** endpoints (Groq, Together AI), and a local **Ollama** server; keep exactly one option active.

In [None]:
# NOTE: Python caches imported modules. If you re-run the %%writefile cells
# above, restart the runtime (or importlib.reload) so the new code is used.
from treedex import TreeDex, GeminiLLM, OpenAICompatibleLLM

# --- Option A: Gemini (default) ---
# Reads the key from Colab Secrets (name: GEMINI_API_KEY), so it never appears
# in the notebook. Comment these lines out if you pick another option.
from google.colab import userdata
GEMINI_KEY = userdata.get("GEMINI_API_KEY")
llm = GeminiLLM(api_key=GEMINI_KEY)

# --- Option B: Groq via OpenAI-compatible ---
# llm = OpenAICompatibleLLM(
#     base_url="https://api.groq.com/openai/v1",
#     api_key="gsk_...",
#     model="llama-3.3-70b-versatile"
# )

# --- Option C: Together AI ---
# llm = OpenAICompatibleLLM(
#     base_url="https://api.together.xyz/v1",
#     api_key="...",
#     model="meta-llama/Llama-3-70b-chat-hf"
# )

# --- Option D: Local Ollama ---
# OllamaLLM defaults to http://localhost:11434, so an Ollama server must be
# reachable from this runtime.
# from treedex import OllamaLLM
# llm = OllamaLLM(model="llama3")

print(f"Using: {llm}")

## 4. Upload & Index a PDF

In [None]:
from google.colab import files as colab_files

# Opens an upload widget; the returned dict is keyed by uploaded filename.
uploaded = colab_files.upload()
if not uploaded:
    # Without this check a cancelled upload fails with a bare IndexError.
    raise RuntimeError("No file uploaded. Re-run this cell and select a document.")
pdf_name = next(iter(uploaded))
if len(uploaded) > 1:
    print(f"{len(uploaded)} files uploaded; only the first is indexed.")
print(f"Uploaded: {pdf_name}")

In [None]:
# Build the index: load the pages (loader picked from the file extension), ask
# the LLM for the section structure of each page group, then assemble the tree.
index = TreeDex.from_file(pdf_name, llm=llm)

## 5. Inspect the Index

In [None]:
# Show the tree structure. Page numbers here are the raw 0-based start/end
# indices, whereas QueryResult.pages_str reports 1-based pages.
index.show_tree()

In [None]:
# View stats: total pages and tokens, total and leaf node counts, and the
# number of top-level sections.
stats = index.stats()
for k, v in stats.items():
    print(f"  {k}: {v}")

## 6. Save & Load

In [None]:
# Save: the tree is written without per-node text, but every page (including
# its text) is stored, so the JSON file is self-contained.
index.save("my_index.json")
print("Saved to my_index.json")

# Load: page ranges and per-node text are rebuilt from the saved pages. An LLM
# is only needed for querying (load() defaults to llm=None).
index2 = TreeDex.load("my_index.json", llm=llm)
print(f"Loaded: {index2.stats()['total_nodes']} nodes")

## 7. Query the Index

In [None]:
# Ask a question: the LLM reads the text-free tree and picks node IDs, and the
# text of those nodes is returned as context.
result = index.query("What are the main topics covered in this document?")

print(f"Relevant nodes: {result.node_ids}")
print(f"Source: {result.pages_str}")  # 1-based page numbers
print(f"Reasoning: {result.reasoning}")
print(f"\nContext ({len(result.context)} chars):")
# Preview only the first 500 characters of the retrieved context.
print(result.context[:500])

In [None]:
# Another query. Only the retrieval call is repeated; the document is not
# re-indexed.
result2 = index.query("Explain the key concepts in the first chapter.")

print(f"Nodes: {result2.node_ids}")
print(f"Source: {result2.pages_str}")
print(f"Reasoning: {result2.reasoning}")

## 8. Swap LLM Provider

Query the same index with a different LLM (e.g., Groq); the index does not need to be rebuilt.

In [None]:
# Swap to Groq for fast inference (uncomment and add your key)
# Tip: store the key in Colab Secrets and pass userdata.get(...) (e.g. a
# secret named GROQ_API_KEY) instead of pasting it into the notebook.
# groq_llm = OpenAICompatibleLLM(
#     base_url="https://api.groq.com/openai/v1",
#     api_key="gsk_YOUR_KEY_HERE",
#     model="llama-3.3-70b-versatile"
# )
#
# # Query with the new LLM — same index, different brain
# result3 = index.query("Summarize the introduction.", llm=groq_llm)
# print(f"Groq result: {result3.node_ids}")
# print(f"Reasoning: {result3.reasoning}")

print("Uncomment the code above and add your Groq API key to try it!")