In [None]:
# %pip install langchain langchain_community langchain-openai langgraph bs4

In [None]:
import os
from html import escape

import bs4
from IPython.display import HTML, display
from langchain.chat_models import init_chat_model
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import START, StateGraph
from typing_extensions import List, TypedDict

from lettucedetect_api.client import LettuceClient
from lettucedetect_api.models import TokenDetectionItem

In [None]:
def display_output(predictions: list[TokenDetectionItem]) -> None:
    """Render token-level hallucination scores as highlighted HTML in the notebook.

    Each token is wrapped in a ``<span>`` whose red background opacity is
    ``hallucination_score * 0.8``, so higher-scoring tokens are shaded more
    strongly. The 0.8 factor presumably keeps the tint below fully opaque red
    so the text stays legible (assuming scores lie in [0, 1] — TODO confirm).

    Args:
        predictions: Token detection items, e.g. the ``predictions`` field of a
            LettuceDetect token-detection response. Each item must expose
            ``token`` (the token text) and ``hallucination_score`` (a number).

    Returns:
        None. The HTML is rendered via ``IPython.display.display``.
    """
    spans = []
    for item in predictions:
        alpha = item.hallucination_score * 0.8
        # Tokens come from LLM output: escape them so characters such as "<" or
        # "&" render literally instead of being interpreted as HTML markup.
        spans.append(
            f'<span style="background-color: rgba(255, 0, 0, {alpha});">'
            f"{escape(item.token)}</span>"
        )
    # pre-wrap preserves newlines and repeated spaces in the answer, which plain
    # HTML would otherwise collapse into single spaces.
    markup = '<div style="white-space: pre-wrap;">' + "".join(spans) + "</div>"
    display(HTML(markup))

In [None]:
# Define Components
# Chat model that writes the answers.
llm = init_chat_model("gpt-4o-mini", model_provider="openai")
# Embedding model and in-memory vector index for the blog chunks.
embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
vector_store = InMemoryVectorStore(embeddings)
# LettuceDetect hallucination-detection API client. The server address can be
# overridden with the LETTUCE_API_URL environment variable; it defaults to a
# locally running server.
LETTUCE_API_URL = os.environ.get("LETTUCE_API_URL", "http://127.0.0.1:8000")
lettuce_client = LettuceClient(LETTUCE_API_URL)

# NOTE: the sea-life instruction appears intended to make the model add content
# that is not supported by the retrieved context, so the hallucination detector
# has something to flag in this demo.
system_message = (
    "You are an assistant for question-answering tasks. "
    "Use the following pieces of retrieved context to answer the question. "
    "If you don't know the answer, just say that you don't know. "
    "Always add facts about sea life and relate those facts to the context in a funny and creative way. "
    "Don't use emojis."
)
# The user turn carries the retrieved context and the question; both
# placeholders are filled in `generate` via prompt_template.invoke(...).
prompt_template = ChatPromptTemplate.from_messages(
    [
        ("system", system_message),
        ("user", "Context: {context}\nQuestion: {question}\n"),
    ],
)

In [None]:
# Load, chunk and index contents of the blog.
# Only elements with the post-content / post-title / post-header CSS classes
# are parsed from the page.
loader = WebBaseLoader(
    web_paths=("https://lilianweng.github.io/posts/2023-06-23-agent/",),
    bs_kwargs={
        "parse_only": bs4.SoupStrainer(class_=("post-content", "post-title", "post-header"))
    },
)
docs = loader.load()
# Chunks of up to 1000 units (characters with the default length function),
# with 200 overlapping between neighbours so boundary text keeps some context.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
all_splits = text_splitter.split_documents(docs)
# Fail fast: if the page layout changes and the strainer matches nothing, the
# index would be empty and every answer would be scored against no context.
assert all_splits, "No text was extracted from the blog post; check the URL and CSS classes."
# add_documents returns the new document IDs; assign to `_` so the cell prints nothing.
_ = vector_store.add_documents(documents=all_splits)

In [None]:
# Define state for application
class State(TypedDict):
    """Shared state passed between the graph's nodes.

    ``question`` is supplied by the caller; ``retrieve`` fills ``context`` and
    ``generate`` fills ``answer`` and ``hallucination_scores``.
    """

    # Chunks returned by the vector-store similarity search.
    context: List[Document]
    # User question that drives retrieval and generation.
    question: str
    # The chat model's answer text.
    answer: str
    # Per-token hallucination predictions from LettuceDetect for `answer`.
    hallucination_scores: List[TokenDetectionItem]


# Define application steps
def retrieve(state: State):
    """Graph node: look up chunks relevant to the question.

    Returns a partial state update that sets ``context`` to the documents
    found by ``vector_store.similarity_search`` for ``state["question"]``.
    """
    query = state["question"]
    matches = vector_store.similarity_search(query)
    return {"context": matches}


def generate(state: State):
    """Graph node: answer the question, then score the answer for hallucinations.

    The retrieved chunks are joined into one context string and used to fill
    the prompt, which is sent to the chat model. The reply is then checked by
    the LettuceDetect client against that same context.

    Returns a partial state update with ``answer`` (the model's reply text) and
    ``hallucination_scores`` (the detector's per-token predictions).
    """
    context_text = "\n\n".join(chunk.page_content for chunk in state["context"])
    question = state["question"]
    prompt = prompt_template.invoke({"question": question, "context": context_text})
    answer = llm.invoke(prompt).content
    # Score against exactly the text the model saw, passed as a single context.
    detection = lettuce_client.detect_token(
        contexts=[context_text],
        question=question,
        answer=answer,
    )
    return {"answer": answer, "hallucination_scores": detection.predictions}


# Compile application: START -> retrieve -> generate
graph_builder = StateGraph(State)
graph_builder.add_sequence([retrieve, generate])
graph_builder.add_edge(START, "retrieve")
graph = graph_builder.compile()

In [None]:
# Ask one question end to end and highlight the answer's tokens by hallucination score.
QUESTION = "What is Task Decomposition?"
response = graph.invoke({"question": QUESTION})
display_output(response["hallucination_scores"])