# Wikibot RAG Demo
Inspired by the LangChain RAG demo.

## Setup

### LangSmith

In [None]:
import getpass
import os

# Set the LangSmith tracing flag for this session.
os.environ["LANGSMITH_TRACING"] = "true"
# Prompt only when no key is present, matching the other credential cells, so
# re-running the cell does not overwrite a key that is already configured.
if not os.environ.get("LANGSMITH_API_KEY"):
    os.environ["LANGSMITH_API_KEY"] = getpass.getpass("Enter API key for LangSmith: ")

### Chat model (Gemini 2.5 Flash-Lite)

In [None]:
# Ask for the Google Gemini key interactively unless the environment already
# holds a non-empty value.
if os.environ.get("GOOGLE_API_KEY", "") == "":
    os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter API key for Google Gemini: ")

from langchain.chat_models import init_chat_model

# Chat model that the generate step invokes to produce answers.
llm = init_chat_model(
    model="gemini-2.5-flash-lite",
    model_provider="google_genai",
)

### Embedding model (text-embedding-3-large)

In [None]:
# Ask for the OpenAI key interactively unless the environment already holds a
# non-empty value.
if os.environ.get("OPENAI_API_KEY", "") == "":
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter API key for OpenAI: ")

from langchain_openai import OpenAIEmbeddings

# Embedding model shared by the vector store for documents and queries.
embeddings = OpenAIEmbeddings(
    model="text-embedding-3-large",
)

### Vector store (FAISS)

In [None]:
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS

# Embed a throwaway query once to learn the vector length the model produces,
# then size the flat L2 index to match.
embedding_dim = len(embeddings.embed_query("hello world"))
index = faiss.IndexFlatL2(embedding_dim)

# Empty store: no documents and no index-to-id mappings yet.
vector_store = FAISS(
    index=index,
    embedding_function=embeddings,
    index_to_docstore_id={},
    docstore=InMemoryDocstore(),
)

## Indexing

### Loading documents

In [None]:
import re
import xml.etree.ElementTree as ET
from langchain import hub
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import START, StateGraph
from typing_extensions import List, TypedDict

# Path to the wiki XML export that iter_pages reads during ingestion.
XML_PATH = "stormlightarchive_pages_current.xml"

try:
    import mwparserfromhell as mw
except ImportError:
    # Optional dependency: without it strip_wiki relies on the regexes alone.
    mw = None

# Self-closing citations (<ref name="x" />) have no matching </ref>; they must
# be removed before the paired pattern runs, otherwise that pattern treats them
# as an opening tag and deletes all text up to the next </ref>.
_SELF_CLOSING_REF_RE = re.compile(r"<ref\b[^>]*/\s*>")
# \b keeps <references ...> from being mistaken for an opening <ref>.
_PAIRED_REF_RE = re.compile(r"<ref\b[^>]*>.*?</ref>", flags=re.DOTALL)


def strip_wiki(text: str) -> str:
    """Reduce wiki markup to plain text.

    When ``mwparserfromhell`` is available its ``strip_code`` does the bulk of
    the work; the regex passes below then remove whatever markup is left (and
    are the only cleanup when the parser is not installed).

    Args:
        text: Raw wikitext; ``None`` or an empty string is accepted.

    Returns:
        The cleaned text with surrounding whitespace stripped, or ``""`` for
        empty input.
    """
    if not text:
        return ""
    if mw:
        parsed = mw.parse(text)
        text = parsed.strip_code(normalize=True, collapse=True)
    # [[Target|Label]] -> Label, [[Target]] -> Target
    text = re.sub(r"\[\[(?:[^|\]]*\|)?([^\]]+)\]\]", r"\1", text)
    # Citations: self-closing first, then <ref>...</ref> pairs with their body.
    text = _SELF_CLOSING_REF_RE.sub("", text)
    text = _PAIRED_REF_RE.sub("", text)
    # Any other HTML-like tag is dropped but its inner text is kept.
    text = re.sub(r"<[^>]+>", "", text)
    # Remove trailing whitespace and blank lines before each newline.
    text = re.sub(r"\s+\n", "\n", text)
    return text.strip()

def normalize_title(title: str) -> str:
    """Return *title* trimmed, lower-cased, and with spaces turned into underscores.

    A missing or empty title yields ``""``.
    """
    if not title:
        return ""
    trimmed = title.strip()
    lowered = trimmed.lower()
    return lowered.replace(" ", "_")

def iter_pages(xml_path: str):
    """Yield ``(title, text)`` string pairs for every ``<page>`` in an XML dump.

    The file is parsed incrementally and each page element is emptied once it
    has been yielded, so page contents are not all held in memory at once.
    The XML namespace, if any, is taken from the root element's tag.

    Args:
        xml_path: Path to the XML export file.

    Yields:
        ``(title, text)`` where *title* is the page's ``<title>`` and *text*
        is the ``<text>`` of its first ``<revision>``; a missing or empty
        element is reported as ``""``.
    """
    with open(xml_path, "rb") as xml_file:
        root = None
        ns = ""
        page_tag = "page"
        for event, elem in ET.iterparse(xml_file, events=("start", "end")):
            if root is None:
                # The first event is the start of the root element.
                root = elem
                if root.tag.startswith("{"):
                    ns = root.tag.split("}")[0] + "}"
                page_tag = f"{ns}page"
                continue
            # Only fully-parsed <page> descendants of the root are of interest.
            if event != "end" or elem is root or elem.tag != page_tag:
                continue

            title_el = elem.find(f"{ns}title")
            rev_el = elem.find(f"{ns}revision")
            text_el = rev_el.find(f"{ns}text") if rev_el is not None else None
            # Element.text is None for empty elements; normalise to "".
            title = (title_el.text or "") if title_el is not None else ""
            text = (text_el.text or "") if text_el is not None else ""
            yield title, text
            # Release the page's subtree now that its strings are handed off.
            elem.clear()

# Splitter used to cut each cleaned page into overlapping chunks before
# indexing; separators are tried in the order listed.
text_splitter = RecursiveCharacterTextSplitter(
    separators=["\n\n", "\n", " ", ""],
    chunk_overlap=200,
    chunk_size=1000,
)

docs: list[Document] = []
skipped = 0
for raw_title, raw_text in iter_pages(XML_PATH):
    # Pages with no wikitext at all are passed over without being counted.
    if not raw_text:
        continue

    clean = strip_wiki(raw_text)
    # Anything shorter than 200 characters after cleanup is counted as skipped.
    if len(clean) < 200:
        skipped += 1
        continue

    title_norm = normalize_title(raw_title)
    splits = text_splitter.split_text(clean)

    # One Document per chunk; chunk_id is the chunk's position within its page.
    for i, chunk in enumerate(splits):
        docs.append(
            Document(
                page_content=chunk,
                metadata=dict(
                    source="stormlight_fandom",
                    title=raw_title,
                    title_norm=title_norm,
                    chunk_id=i,
                ),
            )
        )

# Embed and store every chunk; the returned ids are not needed here.
_ = vector_store.add_documents(documents=docs)
print(f"Ingested {len(docs)} chunks from Stormlight Fandom XML (skipped {skipped} pages).")

## Retrieval and Generation

In [None]:
# Define prompt for question-answering: pull the "rlm/rag-prompt" template
# via langchain's hub module.
# N.B. for non-US LangSmith endpoints, you may need to specify
# api_url="https://api.smith.langchain.com" in hub.pull.
prompt = hub.pull("rlm/rag-prompt")


# Define state for application
class State(TypedDict):
    """State dictionary passed between the graph's steps."""

    # The user's question; read by both retrieve and generate.
    question: str
    # Documents returned by retrieve's similarity search.
    context: List[Document]
    # Content of the model's reply, filled in by generate.
    answer: str


# Define application steps
def retrieve(state: State):
    """Look up documents similar to the question and return them as ``context``."""
    question = state["question"]
    return {"context": vector_store.similarity_search(question)}


def generate(state: State):
    """Answer the question from the retrieved context and return it as ``answer``.

    The page contents of all context documents are joined with blank lines,
    rendered through the prompt template, and sent to the chat model.
    """
    passages = [doc.page_content for doc in state["context"]]
    prompt_inputs = {"question": state["question"], "context": "\n\n".join(passages)}
    reply = llm.invoke(prompt.invoke(prompt_inputs))
    return {"answer": reply.content}


# Wire the two steps into a linear graph (START -> retrieve -> generate)
# and compile it into a runnable application.
graph_builder = StateGraph(State)
graph_builder = graph_builder.add_sequence([retrieve, generate])
graph_builder.add_edge(START, "retrieve")
graph = graph_builder.compile()

In [None]:
# Run the compiled graph on a sample question and print the generated answer.
response = graph.invoke({"question": "Who is Kaladin, according to the first book?"})
print(response["answer"])

In [None]:
# Second sample question, run through the same compiled graph.
response = graph.invoke({"question": "What is the first book in the Stormlight Archive called?"})
print(response["answer"])