# RAG with OpenAI embeddings + FAISS
Build a simple PDF question-answering pipeline using OpenAI embeddings and a local FAISS vector index.

In [None]:
# If a package is missing, uncomment and run the line below.
# %pip install -r requirements.txt

In [None]:

import json
import os
from dataclasses import dataclass
from typing import List, Dict, Any

try:
    import faiss  # from faiss-cpu
except ImportError as exc:
    raise ImportError("faiss is missing. Install with `pip install faiss-cpu` or `%pip install faiss-cpu`." ) from exc

import fitz  # PyMuPDF
import numpy as np
from openai import OpenAI


In [None]:

# --- Configuration ---
# The API key is taken from the environment so it never has to be written
# into the notebook itself; an unset or empty value stops execution here.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise ValueError("Set OPENAI_API_KEY as an environment variable before continuing.")

client = OpenAI(api_key=OPENAI_API_KEY)

# Models used for embedding chunks/queries and for generating answers.
EMBED_MODEL = "text-embedding-3-small"
CHAT_MODEL = "gpt-4o-mini"

# Default word-count bounds used by chunk_text().
CHUNK_MIN_WORDS = 150
CHUNK_MAX_WORDS = 250


In [None]:

def chunk_text(text: str, min_words: int = CHUNK_MIN_WORDS, max_words: int = CHUNK_MAX_WORDS) -> List[str]:
    """Split ``text`` into paragraph-aligned chunks of roughly min_words..max_words words.

    Paragraphs are runs of non-blank lines separated by blank (or
    whitespace-only) lines; the lines of one paragraph are joined with single
    spaces. Consecutive paragraphs are packed into one chunk while the running
    word count stays within ``max_words``. A paragraph longer than
    ``max_words`` is cut on word boundaries into ``max_words``-sized pieces.

    Both limits are soft targets rather than hard guarantees:

    * To avoid tiny chunks, a short trailing piece of a split paragraph, a
      pending chunk still below ``min_words``, and a short final chunk are
      merged into a neighbour, which can push that chunk past ``max_words``.
    * A pending chunk is emitted as-is (possibly below ``min_words``) when an
      over-long paragraph follows it, and a document with a single short
      chunk returns that chunk unchanged.

    Args:
        text: Raw text, e.g. one page of extracted PDF text.
        min_words: Preferred minimum words per chunk (must be > 0).
        max_words: Preferred maximum words per chunk (must be >= min_words).

    Returns:
        The chunks in document order; an empty list if ``text`` has no words.

    Raises:
        ValueError: If either bound is non-positive or min_words > max_words.
    """
    if max_words <= 0 or min_words <= 0 or min_words > max_words:
        raise ValueError("Invalid min/max word configuration.")

    # Group non-blank lines into paragraphs; a blank line closes a paragraph.
    paragraphs: List[str] = []
    buffer: List[str] = []
    for line in text.splitlines():
        line = line.strip()
        if line:
            buffer.append(line)
        elif buffer:
            paragraphs.append(" ".join(buffer))
            buffer = []
    if buffer:
        paragraphs.append(" ".join(buffer))

    if not paragraphs:
        return []

    def split_long_paragraph(words: List[str]) -> List[str]:
        """Cut an over-long paragraph into max_words-sized pieces.

        A final remainder shorter than min_words is appended to the previous
        piece instead of standing alone.
        """
        pieces: List[List[str]] = []
        idx = 0
        n = len(words)
        while idx < n:
            remaining = n - idx
            if remaining > max_words:
                take = max_words
            elif remaining < min_words and pieces:
                pieces[-1].extend(words[idx:])
                break
            else:
                take = remaining
            pieces.append(words[idx:idx + take])
            idx += take
        return [" ".join(piece) for piece in pieces]

    chunks: List[str] = []
    current: List[str] = []  # paragraphs waiting to be emitted as one chunk
    current_words = 0        # total word count of ``current``

    def flush_current() -> None:
        """Emit the pending paragraphs as one chunk and reset the accumulator."""
        # Both names are rebound below. If ``current`` is not declared
        # nonlocal, Python treats it as a local of this helper and the
        # ``if current`` check raises UnboundLocalError.
        nonlocal current, current_words
        if current:
            chunks.append(" ".join(current))
            current = []
            current_words = 0

    for para in paragraphs:
        words = para.split()
        wcount = len(words)

        if wcount > max_words:
            flush_current()
            chunks.extend(split_long_paragraph(words))
            continue

        if not current:
            current = [para]
            current_words = wcount
            continue

        if current_words + wcount <= max_words:
            current.append(para)
            current_words += wcount
            continue

        if current_words < min_words:
            # The pending chunk is still too small: absorb this paragraph even
            # though that overflows max_words, rather than emit a tiny chunk.
            current.append(para)
            current_words += wcount
            flush_current()
        else:
            flush_current()
            current = [para]
            current_words = wcount

    if current:
        if current_words < min_words and chunks:
            # Fold a short tail into the previous chunk instead of emitting it.
            chunks[-1] = chunks[-1] + " " + " ".join(current)
        else:
            flush_current()

    return chunks


def load_pdf(pdf_path: str) -> List[Dict[str, Any]]:
    """Open ``pdf_path`` and return one ``{"page": n, "text": ...}`` dict per page.

    Page numbers are 1-based (PyMuPDF's ``page.number`` plus one). The text is
    the plain-text extraction with surrounding whitespace stripped, or an
    empty string when a page yields none. The document is closed before
    returning, even if extraction fails part-way.
    """
    with fitz.open(pdf_path) as doc:
        return [
            {"page": pg.number + 1, "text": (pg.get_text("text") or "").strip()}
            for pg in doc
        ]


def build_documents(pages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Chunk every page's text and tag each chunk with its page and an id.

    Ids have the form ``p<page>_c<index of the chunk within that page>``.
    Chunks are returned in page order, then in order within each page.
    """
    return [
        {"id": f"p{pg['page']}_c{i}", "page": pg["page"], "text": piece}
        for pg in pages
        for i, piece in enumerate(chunk_text(pg["text"]))
    ]


def embed_texts(texts: List[str], batch_size: int = 256) -> np.ndarray:
    """Embed ``texts`` with EMBED_MODEL and return a float32 matrix, one row per text.

    Texts are sent in batches of at most ``batch_size`` per request. The
    embeddings endpoint caps both the number of inputs and the total tokens
    per request, so one request for a whole large PDF can be rejected. Rows
    are returned in the same order as ``texts``.

    Args:
        texts: Non-empty list of strings to embed.
        batch_size: Maximum number of texts per API request (must be > 0).

    Returns:
        A ``(len(texts), dim)`` float32 array.

    Raises:
        ValueError: If ``texts`` is empty or ``batch_size`` is not positive.
    """
    if not texts:
        raise ValueError("embed_texts() needs at least one text to embed.")
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer.")

    vectors: List[List[float]] = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        response = client.embeddings.create(model=EMBED_MODEL, input=batch)
        vectors.extend(item.embedding for item in response.data)
    return np.array(vectors, dtype="float32")


class FaissStore:
    """A FAISS ``IndexFlatL2`` (exact L2 search) paired with per-vector metadata.

    ``self.meta[i]`` is the metadata dict for the i-th vector in
    ``self.index``. Vectors and metadata are appended together by :meth:`add`
    so the two stay aligned.
    """

    def __init__(self, dim: int):
        self.index = faiss.IndexFlatL2(dim)
        self.meta: List[Dict[str, Any]] = []

    def add(self, embeddings: np.ndarray, metadatas: List[Dict[str, Any]]):
        """Append an ``(n, dim)`` batch of vectors and their ``n`` metadata dicts.

        Raises:
            ValueError: If ``embeddings`` is not 2-D, its width differs from
                the index dimension, or its row count differs from
                ``len(metadatas)``.
        """
        if embeddings.ndim != 2:
            raise ValueError("Embeddings must be a 2-D array of shape (n, dim).")
        if embeddings.shape[1] != self.index.d:
            raise ValueError(
                f"Embedding dimension {embeddings.shape[1]} does not match index dimension {self.index.d}."
            )
        if embeddings.shape[0] != len(metadatas):
            raise ValueError("Embeddings and metadata counts do not match.")
        self.index.add(embeddings)
        self.meta.extend(metadatas)

    def search(self, query_embedding: np.ndarray, k: int = 5) -> List[Dict[str, Any]]:
        """Return up to ``k`` stored entries nearest to one query vector, nearest first.

        ``query_embedding`` may be a flat ``(dim,)`` vector or a single
        ``(1, dim)`` row. Each result is the stored metadata plus ``"score"``,
        the squared L2 distance reported by ``IndexFlatL2``. A lower score
        means a closer match. Returns ``[]`` when ``k <= 0`` or the index is
        empty.
        """
        if k <= 0 or self.index.ntotal == 0:
            return []
        # FAISS only returns real hits up to ntotal, so never ask for more.
        k = min(k, self.index.ntotal)
        query = np.array(query_embedding, dtype="float32").reshape(1, -1)
        distances, indices = self.index.search(query, k)
        results = []
        for dist, idx in zip(distances[0], indices[0]):
            if idx == -1:  # FAISS marks missing neighbours with -1
                continue
            results.append({"score": float(dist), **self.meta[int(idx)]})
        return results

    def save(self, index_path: str = "faiss.index", meta_path: str = "metadata.json"):
        """Write the index with FAISS and the metadata list as UTF-8 JSON."""
        faiss.write_index(self.index, index_path)
        with open(meta_path, "w", encoding="utf-8") as f:
            json.dump(self.meta, f, ensure_ascii=False, indent=2)

    @classmethod
    def load(cls, index_path: str = "faiss.index", meta_path: str = "metadata.json") -> "FaissStore":
        """Rebuild a store from files written by :meth:`save`.

        Raises:
            ValueError: If the index and metadata files disagree on the number
                of entries (e.g. they come from different runs).
        """
        index = faiss.read_index(index_path)
        with open(meta_path, "r", encoding="utf-8") as f:
            meta = json.load(f)
        if index.ntotal != len(meta):
            raise ValueError(
                f"Index holds {index.ntotal} vectors but metadata has {len(meta)} entries."
            )
        store = cls(index.d)
        store.index = index
        store.meta = meta
        return store


## 1) Ingest a PDF

In [None]:

pdf_path = "./path/to/your.pdf"  # update this to your PDF
if not os.path.isfile(pdf_path):
    # Stop early with an actionable message when the placeholder path above
    # has not been replaced with a real file.
    raise FileNotFoundError(f"PDF not found at {pdf_path!r}; set pdf_path to an existing file.")
pages = load_pdf(pdf_path)
print(f"Loaded {len(pages)} pages from {pdf_path}")


## 2) Chunk and embed

In [None]:

documents = build_documents(pages)
print(f"Prepared {len(documents)} chunks")
if not documents:
    # Nothing was extracted (for example, a PDF with no text layer). Stop
    # here with a clear message instead of sending an empty request to the
    # embeddings API.
    raise ValueError("No text chunks were extracted from the PDF; nothing to embed.")

embeddings = embed_texts([doc["text"] for doc in documents])
vector_dim = embeddings.shape[1]
store = FaissStore(vector_dim)
store.add(embeddings, documents)
print(f"FAISS index built with dimension {vector_dim}")


## 3) Retrieve

In [None]:

def retrieve(query: str, k: int = 5):
    """Embed ``query``, fetch the ``k`` nearest chunks from ``store``, and print a preview of each.

    Returns the hit dicts from ``store.search``: the chunk metadata plus
    ``score``, a squared L2 distance where lower means more similar.
    """
    preview_len = 140
    query_vec = embed_texts([query])[0]
    results = store.search(query_vec, k=k)
    for res in results:
        text = res["text"].replace("\n", " ")
        # Only mark the preview as truncated when it actually is.
        suffix = "..." if len(text) > preview_len else ""
        print(f"Page {res['page']} (score={res['score']:.4f}): {text[:preview_len]}{suffix}")
    return results

sample_results = retrieve("What is this document about?", k=4)


## 4) Generate answer with retrieved context

In [None]:

def answer(query: str, k: int = 5) -> str:
    """Answer ``query`` with CHAT_MODEL, using the top-``k`` retrieved chunks as context.

    Each chunk is prefixed with its page label (``[p<page>]``) so the model
    can cite pages. Retrieval previews are printed as a side effect of
    ``retrieve``.

    Returns:
        The model's reply text, or ``""`` if the reply carries no text content.
    """
    hits = retrieve(query, k=k)
    context = "\n\n".join(f"[p{hit['page']}] {hit['text']}" for hit in hits)
    prompt = (
        "You are a concise assistant. Use the provided context to answer the question. "
        "Cite pages in brackets like [p2]. If unsure, say you are not sure.\n\n"
        f"Context:\n{context}\n\nQuestion: {query}"
    )

    response = client.chat.completions.create(
        model=CHAT_MODEL,
        messages=[
            {"role": "system", "content": "Answer using only the provided context."},
            {"role": "user", "content": prompt},
        ],
        temperature=0.2,
    )
    # ``message.content`` is optional in the Chat Completions response (it can
    # be null); fall back to "" so the declared ``str`` return type holds.
    return response.choices[0].message.content or ""

print(answer("Give me a two sentence summary."))


## 5) Persist / reload index (optional)

In [None]:

# Write the FAISS index and its chunk metadata to disk, side by side.
store.save(index_path="faiss.index", meta_path="metadata.json")

# In a later session, restore both without re-embedding:
# store = FaissStore.load("faiss.index", "metadata.json")
