# LangGraph + DSPy

In [1]:
# Install the dependencies if needed.
# %pip install -U dspy-ai
# %pip install -U openai jinja2
# %pip install -U langchain langchain-community langchain-openai langchain-core langgraph

### 0) source code modification

```
class LangChainPredict(Predict, Runnable):
    def __init__(self, prompt, llm, state_graph=False, **config):
        self.state_g = state_graph


    def forward():
        if self.state_graph:
            return {"langpredict_output": output}   
```

state_graph parameter is added to return a dictionary to update the state

### 1) Setting Up

In [2]:
import dspy

from dspy.evaluate.evaluate import Evaluate
from dspy.teleprompt import BootstrapFewShotWithRandomSearch

colbertv2 = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')

dspy.configure(rm=colbertv2)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from langchain_openai import OpenAI
from langchain.cache import SQLiteCache

llm = OpenAI(model_name="gpt-3.5-turbo-instruct", temperature=0)

In [4]:
# def retrieve(state):
#     print("==> retrieve")
#     result = dspy.Retrieve(k=5)(state["question"]).passages
#     print("result", result)
#     return {"context": result}
retrieve = lambda x: dspy.Retrieve(k=5)(x["question"]).passages

### 2) Defining a graph as a `LangGraph` expression

In [5]:
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

from dspy.predict.langchain import LangChainPredict

prompt = PromptTemplate.from_template(
    "Given {context}, answer the question `{question}` as a tweet."
)

# ! LangChainPredict not working with StateGraph since it has to return dict
# Modified DSPy langchain.py to make it work with StateGraphs
# use state_graph=True to return a dict
generate_answer = LangChainPredict(prompt, llm, state_graph=True)

In [6]:
def decide_to_use_retrieve(state):

    use_retrieve = False
    if state["cached_questions"] is None:
        return "retrieve"
    if state["question"] not in state["cached_questions"]:
        use_retrieve = True

    if use_retrieve:
        print("Using retrieve")
        return "retrieve"
    else:
        print("Not using retrieve")
        return "generate_answer"

In [7]:
from typing import TypedDict

class State(TypedDict):
    question: str
    context: str
    output: str
    cached_questions: list
    langpredict_output: str

In [8]:
from langgraph.graph import END, StateGraph

g = StateGraph(State)

g.add_node("start", RunnablePassthrough())
g.add_conditional_edges("start", decide_to_use_retrieve)

g.add_node("retrieve", RunnablePassthrough.assign(context=retrieve))
g.add_edge("retrieve", "generate_answer")

g.add_node("generate_answer", generate_answer)
g.add_edge("generate_answer", END)

g.set_entry_point("start")

compiled_g = g.compile()

In [9]:
# compiled_graph.invoke(
#     {
#         "question": "what's the capital of Korea",
#         "cached_questions": [
#             "what's the capital of France",
#             "what's the capital of Germany",
#         ],
#     }
# )

### 3) Converting the chain into a **DSPy module**

In [10]:
# From DSPy, import the modules that know how to interact with LangChain LCEL.
from dspy.predict.langchain import LangChainModule

graph_dspy_module = LangChainModule(compiled_graph)

In [11]:
for name, node in compiled_graph.get_graph().nodes.items():
    print(f"{node.data}")

<class 'pydantic.v1.main.LangGraphInput'>
<class 'pydantic.v1.main.LangGraphOutput'>

mapper={
  context: RunnableLambda(lambda x: dspy.Retrieve(k=5)(x['question']).passages)
}
LangChainPredict(Template(Essential Instructions: Respond to the given question based on the provided context in the style of a tweet. The response should be concise, engaging, and limited to the character count typical for a tweet (up to 280 characters)., ['Context:', 'Question:', 'Tweet Response:']))


In [12]:
graph_dspy_module.modules

[LangChainPredict(Template(Essential Instructions: Respond to the given question based on the provided context in the style of a tweet. The response should be concise, engaging, and limited to the character count typical for a tweet (up to 280 characters)., ['Context:', 'Question:', 'Tweet Response:']))]

In [13]:
# graph_dspy_module.invoke(
#     {
#         "question": "what's the capital of Korea",
#         # "cached_questions": [
#         #     "what's the capital of France",
#         #     "what's the capital of Germany",
#         # ],
#     }
# )

### 4) Trying the module


In [14]:
# question = "In what region was Eddy Mazzoleni born?"

# graph_dspy_module.invoke({"question": question})

### 5) Optimizing the module

In [15]:
# We took the liberty to define this metric and load a few examples from a standard QA dataset.
# Let's impore them from `tweet_metric.py` in the same directory that contains this notebook.
from tweet_metric import metric, trainset, valset, devset

# We loaded 200, 50, and 150 examples for training, validation (tuning), and development (evaluation), respectively.
# You could load less (or more) and, chances are, the right DSPy optimizers will work well for many problems.
len(trainset), len(valset), len(devset)

  table = cls._concat_blocks(blocks, axis=0)


(200, 50, 150)

In [19]:
trainset = trainset[:1]
valset = valset[:1]

In [20]:
optimizer = BootstrapFewShotWithRandomSearch(metric=metric, max_bootstrapped_demos=1, num_candidate_programs=1)

In [21]:
# Now use the optimizer to *compile* the chain. This could take 5-10 minutes, unless it's cached.
optimized_chain = optimizer.compile(graph_dspy_module, trainset=trainset, valset=valset)

LangChainModule forward kwargs =
 {'question': 'Are both Cangzhou and Qionghai in the Hebei province of China?'}
LangGraph Input:  {'question': 'Are both Cangzhou and Qionghai in the Hebei province of China?'}
RunnableCallable.invoke INPUT:  {'question': 'Are both Cangzhou and Qionghai in the Hebei province of China?'}
RunnableCallable.invoke INPUT:  <object object at 0x1272ee5d0>
RunnableCallable.invoke input:  <object object at 0x1272ee5d0>


ERROR:dspy.evaluate.evaluate:[2m2024-05-14T20:41:30.621738Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 Expected dict, got <object object at 0x1272ee5d0>[0m [[0m[1m[34mdspy.evaluate.evaluate[0m][0m [36mfilename[0m=[35mevaluate.py[0m [36mlineno[0m=[35m147[0m
Average Metric: 0.0 / 1  (0.0): 100%|██████████| 1/1 [00:00<00:00, 1293.74it/s]


LangChainModule forward kwargs =
 {'question': 'Are both Cangzhou and Qionghai in the Hebei province of China?'}
LangGraph Input:  {'question': 'Are both Cangzhou and Qionghai in the Hebei province of China?'}
RunnableCallable.invoke INPUT:  {'question': 'Are both Cangzhou and Qionghai in the Hebei province of China?'}
RunnableCallable.invoke INPUT:  <object object at 0x1272ee5b0>
RunnableCallable.invoke input:  <object object at 0x1272ee5b0>


ERROR:dspy.evaluate.evaluate:[2m2024-05-14T20:41:30.650805Z[0m [[31m[1merror    [0m] [1mError for example in dev set: 		 Expected dict, got <object object at 0x1272ee5b0>[0m [[0m[1m[34mdspy.evaluate.evaluate[0m][0m [36mfilename[0m=[35mevaluate.py[0m [36mlineno[0m=[35m147[0m
Average Metric: 0.0 / 1  (0.0): 100%|██████████| 1/1 [00:00<00:00, 1153.23it/s]
Average Metric: 0.0 / 4  (0.0):  40%|████      | 4/10 [00:46<01:09, 11.63s/it]


AttributeError: 'Template' object has no attribute 'equals'

In [20]:
kw = {
    "question": "Are both Chico Municipal Airport and William R. Fairchild International Airport in California?"
}

print(type(dict(**kw)))

<class 'dict'>


In [1]:
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError
from langchain_core.messages import HumanMessage, ToolMessage
from langchain_core.output_parsers.openai_tools import PydanticToolsParser


class Reflection(BaseModel):
    missing: str = Field(description="Critique of what is missing.")
    superfluous: str = Field(description="Critique of what is superfluous")


class AnswerQuestion(BaseModel):
    """Answer the question. Provide an answer, reflection, and then follow up with search queries to improve the answer."""

    answer: str = Field(description="~250 word detailed answer to the question.")
    reflection: Reflection = Field(description="Your reflection on the initial answer.")
    search_queries: list[str] = Field(
        description="1-3 search queries for researching improvements to address the critique of your current answer."
    )

AnswerQuestion.__name__

'AnswerQuestion'

In [1]:
from langchain.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field, validator
from langchain_openai import OpenAI

model = OpenAI(model_name="gpt-3.5-turbo-instruct", temperature=0.0)


# Define your desired data structure.
class Joke(BaseModel):
    setup: str = Field(description="question to set up a joke")
    punchline: str = Field(description="answer to resolve the joke")

    # You can add custom validation logic easily with Pydantic.
    @validator("setup")
    def question_ends_with_question_mark(cls, field):
        if field[-1] != "?":
            raise ValueError("Badly formed question!")
        return field


# Set up a parser + inject instructions into the prompt template.
parser = PydanticOutputParser(pydantic_object=Joke)

prompt = PromptTemplate(
    template="Answer the user query.\n{format_instructions}\n{query}\n",
    input_variables=["query"],
    partial_variables={"format_instructions": parser.get_format_instructions()},
)


In [2]:
parser.get_format_instructions()

'The output should be formatted as a JSON instance that conforms to the JSON schema below.\n\nAs an example, for the schema {"properties": {"foo": {"title": "Foo", "description": "a list of strings", "type": "array", "items": {"type": "string"}}}, "required": ["foo"]}\nthe object {"foo": ["bar", "baz"]} is a well-formatted instance of the schema. The object {"properties": {"foo": ["bar", "baz"]}} is not well-formatted.\n\nHere is the output schema:\n```\n{"properties": {"setup": {"title": "Setup", "description": "question to set up a joke", "type": "string"}, "punchline": {"title": "Punchline", "description": "answer to resolve the joke", "type": "string"}}, "required": ["setup", "punchline"]}\n```'

In [None]:

# And a query intended to prompt a language model to populate the data structure.
prompt_and_model = prompt | model
output = prompt_and_model.invoke({"query": "Tell me a joke."})
parser.invoke(output)