# Hypothetical Document Embeddings RAG

### Setup

In [None]:
% pip install -U langchain-community tiktoken langchain-openai langchainhub chromadb langchain langgraph langchain-text-splitters

In [None]:
import getpass
import os


def _set_env(key: str):
    """Ensure the environment variable ``key`` is set.

    When ``key`` is already present in ``os.environ`` nothing happens.
    Otherwise the user is prompted (input hidden via ``getpass``) with the
    prompt ``"<key>:"`` and the entered value is stored in ``os.environ``.
    """
    if key in os.environ:
        return
    os.environ[key] = getpass.getpass(f"{key}:")


_set_env("OPENAI_API_KEY")
# The Postgres connection string is prompted for just like the API key, so a
# missing variable leads to a (hidden-input) prompt instead of a bare KeyError.
_set_env("PGVECTOR_CONNECTION_STRING")
PG_VECTOR_URL = os.environ["PGVECTOR_CONNECTION_STRING"]

### Retrieval

In [None]:
### Build Index
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
# PGVector lives in langchain_community; the old langchain.vectorstores path
# is a deprecated re-export.
from langchain_community.vectorstores.pgvector import PGVector
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings

# Source articles that make up the knowledge base.
urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

# Each loader call returns a list of documents, so `docs` is a list of lists;
# `docs_list` flattens it into a single list.
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

# Split into chunks of 500 with an overlap of 50 (sizes as interpreted by the
# tiktoken-based splitter constructor).
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=500, chunk_overlap=50
)
doc_splits = text_splitter.split_documents(docs_list)

# Add to vectorstore

# # ChromaDB
# vectorstore = Chroma.from_documents(
#     documents=doc_splits,
#     collection_name="hyde-rag",
#     embedding=OpenAIEmbeddings(model="text-embedding-3-small"),
# )

# Pass the connection string explicitly so the value read during setup is the
# one used, matching how the HyDE vector store is constructed later on.
# NOTE(review): re-running this cell presumably inserts the chunks again into
# the same collection — confirm whether the collection should be reset first.
vectorstore = PGVector.from_documents(
    documents=doc_splits,
    collection_name="hyde-rag",
    embedding=OpenAIEmbeddings(model="text-embedding-3-small"),
    connection_string=PG_VECTOR_URL,
)
retriever = vectorstore.as_retriever()

In [None]:
# Build the HyDE embedder: an LLM chain that drafts a short passage for a
# question, paired with the embedding model used for the index.
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import OpenAI

# LLM used only for drafting the passage; output is capped at 250 tokens and
# failed requests are retried up to 5 times.
llm_model = OpenAI(max_tokens=250, max_retries=5)

# System instructions for the passage-writing step.
system = """You are an expert content writer. \n
Your task is to generate clear, concise, and relevant text for embedding in documents or systems. \n
Ensure accuracy, readability, and context-appropriate tone in your response."""

# The prompt's only input variable is `question`.
hyde_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Write a brief passage answering the question below. \n\n User question: {question}"),
    ]
)

# prompt -> LLM; the chain takes {"question": ...} and produces the drafted passage.
llm_chain = hyde_prompt | llm_model
# base_embeddings uses the same model ("text-embedding-3-small") that embedded
# the indexed chunks, so both sides share one vector space.
# NOTE(review): presumably the embedder embeds the LLM-drafted passage rather
# than the raw query (the HyDE technique) — confirm against the LangChain docs.
hyde_embeddings = HypotheticalDocumentEmbedder(
    llm_chain=llm_chain, base_embeddings=OpenAIEmbeddings(model="text-embedding-3-small")
)


In [None]:
# Attach to the "hyde-rag" collection (the same name used when indexing), but
# with hyde_embeddings as the embedding function for this store.
# Note: this rebinds `vectorstore`, replacing the store created during indexing.
vectorstore = PGVector(
    connection_string=PG_VECTOR_URL,
    collection_name="hyde-rag",
    embedding_function=hyde_embeddings,
)
# Retriever used by the sample query below and by the graph's `retrieve` node.
hyde_retriever = vectorstore.as_retriever()

In [None]:
# Smoke-test the HyDE retriever with a short query. `question` and `docs` are
# reused by the generation cell that follows; note `docs` rebinds the name
# used for the raw loaded pages during indexing.
question = "agent memory"
docs = hyde_retriever.invoke(question)
# Takes the second hit; assumes at least two documents are returned —
# TODO confirm index 1 (rather than 0) is intended.
doc_txt = docs[1].page_content
# print(doc_txt)

### Generate

In [None]:
### Generate
# Prompt and chat model used for the answer-generation step.

from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI

# Prompt
# Pulled from the hub by name; it is invoked below with "context" and "question".
prompt = hub.pull("rlm/rag-prompt")

# LLM
# gpt-4o with temperature 0.
llm = ChatOpenAI(model_name="gpt-4o", temperature=0)

# Post-processing
def format_docs(docs):
    """Return the ``page_content`` of every document, separated by blank lines.

    Args:
        docs: Iterable of objects exposing a ``page_content`` string.

    Returns:
        A single string with the contents joined by ``"\\n\\n"``; an empty
        string when ``docs`` is empty.
    """
    contents = [doc.page_content for doc in docs]
    return "\n\n".join(contents)

# Chain
# prompt -> chat model -> plain string output.
rag_chain = prompt | llm | StrOutputParser()

# Run
# Pass the retrieved documents through format_docs so the prompt receives
# their text only, rather than the string form of the Document objects.
generation = rag_chain.invoke({"context": format_docs(docs), "question": question})
print(generation)

## Graph

### Graph State

In [None]:
from typing import List
from typing_extensions import TypedDict


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: the user question driving retrieval and generation
        generation: the answer produced by the generate node
        documents: the documents returned by the retrieve node
    """

    question: str
    generation: str
    # NOTE(review): the retrieve node stores whatever the retriever returns
    # here (document objects, not plain strings) — confirm whether the
    # annotation should be tightened.
    documents: List[str]

### Graph Nodes

In [None]:
from langchain.schema import Document


def retrieve(state):
    """
    Graph node: fetch documents for the question held in the state.

    Args:
        state (dict): The current graph state; must contain "question"

    Returns:
        dict: State update with "documents" (the retriever's results) and the
        unchanged "question"
    """

    print("---RETRIEVE---")
    query = state["question"]

    # Look up documents through the HyDE-backed retriever
    return {"documents": hyde_retriever.invoke(query), "question": query}


def generate(state):
    """
    Generate answer

    Args:
        state (dict): The current graph state; must contain "question" and
            "documents"

    Returns:
        state (dict): New key added to state, generation, that contains LLM generation
    """

    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]

    # RAG generation
    # Feed the prompt the documents' text (via format_docs) instead of the
    # string form of the document objects themselves.
    context = format_docs(documents)
    generation = rag_chain.invoke({"context": context, "question": question})
    return {"documents": documents, "question": question, "generation": generation}



### Build Graph

In [None]:
# Assemble the two-node workflow: START -> retrieve -> generate -> END.
from langgraph.graph import END, StateGraph, START
from langgraph.checkpoint.memory import MemorySaver

# Checkpointer instance; it is not passed to compile() in this cell and is
# only used when the workflow is recompiled further down.
memory = MemorySaver()
workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("retrieve", retrieve)
workflow.add_node("generate", generate)

# Build graph
# Strictly linear flow, no conditional edges.
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "generate")
workflow.add_edge("generate", END)

# Compile
# Compiled without a checkpointer here.
graph = workflow.compile()

In [None]:
from IPython.display import Image, display

try:
    display(Image(graph.get_graph(xray=True).draw_mermaid_png()))
except Exception as exc:
    # Rendering requires some extra dependencies and is optional, so carry on,
    # but report why no diagram appeared instead of failing silently.
    print(f"Skipping graph visualization: {exc!r}")

### Invoke Graph

In [None]:
from pprint import pprint

# Run
inputs = {"question": "What does Lilian Weng say about the types of agent memory?"}

# Track the answer explicitly instead of relying on the loop variable still
# being bound (and still holding a "generation" key) after the loop ends.
final_generation = None
for output in graph.stream(inputs):
    for key, value in output.items():
        # Node
        pprint(f"Node '{key}':")
        # Optional: print full state at each node
        # pprint(value, indent=2, width=80, depth=None)
        if isinstance(value, dict) and "generation" in value:
            final_generation = value["generation"]
    pprint("\n---\n")

# Final generation (None if no node produced one)
pprint(final_generation)

In [None]:
# Recompile the same workflow, this time with the MemorySaver checkpointer
# created earlier; this rebinds `graph`.
graph = workflow.compile(checkpointer=memory)

In [None]:
# Run the checkpointed graph once under thread id "1".
config = {"configurable": {"thread_id": "1"}}
inputs = {"question": "What does Lilian Weng say about the types of agent memory?"}

# Despite its name, `messages` is the value returned by graph.invoke; it is
# read below by the state keys "question" and "generation".
messages = graph.invoke(inputs, config)
print("Human:", messages["question"])
print("Ai:", messages["generation"])