In [None]:
! pip install -qU langchain_community langchain-openai langchain-anthropic langchain langgraph bs4

# Code Generation with RAG and Self-Correction

In this section, we will start with a set of documentation specified by a user. We then use a long-context LLM to ingest it and perform RAG to answer a question based on it. We will invoke a tool to produce a structured output. Finally, we will perform two unit tests (an import check and a code-execution check) prior to returning the solution to the user.

In [None]:
import os

# Keep any key that is already exported in the environment. The placeholder
# strings are only a fallback and must be replaced with real keys before the
# model calls below can succeed.
os.environ.setdefault("ANTHROPIC_API_KEY", "YOUR_ANTHROPIC_API_KEY")
os.environ.setdefault('OPENAI_API_KEY', 'YOUR_OPENAI_API_KEY')

# Docs

Load LangChain Expression Language (LCEL) docs as an example.

In [8]:
from bs4 import BeautifulSoup as Soup
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader


# LCEL docs
# Root page of the crawl: the LCEL page under the docs' "concepts" section.
url = 'https://python.langchain.com/docs/concepts/lcel/'
loader = RecursiveUrlLoader(
    url=url,
    max_depth=20,
    # Reduce each fetched HTML page to its visible text.
    extractor=lambda x: Soup(x, 'html.parser').text,
)
docs = loader.load()

# Sort the list based on the URLs, reverse that order, and get the text
d_sorted = sorted(docs, key=lambda x: x.metadata['source'])
d_reversed = list(reversed(d_sorted))

# Single string holding every page, with a visible separator between pages;
# this is what gets substituted for {context} in the prompts.
concatenated_content = "\n\n\n ---\n\n\n".join(
    [doc.page_content for doc in d_reversed]
)

# LLMs

We will first try OpenAI and Claude 3 with function calling. We will create a `code_gen_chain` with either OpenAI or Claude and test them.

In [10]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field


### OpenAI

# Code-generation prompt: a system message that embeds the concatenated docs
# via {context}, followed by a placeholder that is filled with the running
# conversation passed in as `messages`.
code_gen_prompt = ChatPromptTemplate.from_messages(
    [
        (
            'system',
            """You are a coding assistant with expertise in LCEL, LangChain expression language. \n
            Here is a full set of LCEL documentation: \n ------- \n {context} \n ------- \n
            Answer the user question based on the above provided documentation. Ensure any code
            you provided can be executed with all required imports and variables defined.
            Structure your answer with a description of the code solution. \n
            Then list the imports. And finally list the functioning code block.
            Here is the user question:""",
        ),
        (
            'placeholder', '{messages}',
        )
    ]
)

# Data model
# Structured-output schema shared by the OpenAI and Anthropic chains. The graph
# nodes read `prefix`, `imports` and `code` from instances of this class, and
# the check node executes `imports` and `imports + code`.
# NOTE(review): the class docstring and the Field descriptions are presumably
# forwarded to the model as the tool schema by `with_structured_output`, so
# treat them as prompt text and confirm before rewording them.
class Code(BaseModel):
    """Schema for code solutions to questions about LCEL."""
    prefix: str = Field(description='Description of the problem and approach')
    imports: str = Field(description='Code block import statements')
    code: str = Field(description='Code block not including import statements')


# Model under test for the OpenAI variant; temperature 0 for repeatable output.
expt_llm = 'gpt-3.5-turbo'
llm = ChatOpenAI(temperature=0, model=expt_llm)

# Prompt piped into the model, with the reply coerced into a `Code` instance.
code_gen_chain_oai = code_gen_prompt | llm.with_structured_output(Code)

# Smoke test: ask one question with the full documentation as context.
question = "How do I build a RAG chain in LCEL?"
solution = code_gen_chain_oai.invoke(
    {'messages': [('user', question)], 'context': concatenated_content}
)

solution

Code(prefix='To build a RAG chain in LCEL, you can define a series of nodes and edges to represent the chain. Each node will have a unique identifier and can contain data. Edges will connect the nodes in a specific order. Here is an example code snippet to create a simple RAG chain in LCEL.', imports="const { Graph } = require('langchain-js');", code="const ragChain = new Graph();\n\n// Define nodes\nragChain.addNode({ id: 'A', data: 'Node A' });\nragChain.addNode({ id: 'B', data: 'Node B' });\nragChain.addNode({ id: 'C', data: 'Node C' });\n\n// Define edges\nragChain.addEdge('A', 'B');\nragChain.addEdge('B', 'C');\n\n// Print the RAG chain\nconsole.log(ragChain.toString());")

In [6]:
from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate


### Anthropic

# Prompt to enforce tool use: same structure as the OpenAI prompt (docs in
# {context}, conversation in the `messages` placeholder), with an explicit
# instruction to invoke the code tool.
code_gen_prompt_claude = ChatPromptTemplate.from_messages(
    [
        (
            'system',
            """<instructions> You are a coding assistant with expertise in LCEL, LangChain expression language. \n
    Here is the LCEL documentation:  \n ------- \n  {context} \n ------- \n Answer the user  question based on the \n
    above provided documentation. Ensure any code you provide can be executed with all required imports and variables \n
    defined. Structure your answer: 1) a prefix describing the code solution, 2) the imports, 3) the functioning code block. \n
    Invoke the code tool to structure the output correctly. </instructions> \n Here is the user question:""",
        ),
        (
            'placeholder', '{messages}',
        )
    ]
)

# Data model: reuses the `Code` schema defined in the OpenAI cell, so that cell
# must have been run first.

# LLM
# NOTE: this rebinds `expt_llm` and `llm`, replacing the OpenAI values.
expt_llm = 'claude-3-haiku-20240307'
llm = ChatAnthropic(model=expt_llm, default_headers={'anthropic-beta': 'tools-2024-04-04'})

# include_raw=True makes the result a dict with `raw`, `parsed` and
# `parsing_error` keys, which the output check and `parse_output` rely on.
structured_llm_claude = llm.with_structured_output(Code, include_raw=True)


# Optional: Check for errors in case tool use is flaky
def check_claude_output(tool_output):
    """Validate a structured-output result and pass it through unchanged.

    Args:
        tool_output: Mapping with the keys `raw`, `parsed` and `parsing_error`.

    Returns:
        The same `tool_output` object when it holds a parsed result.

    Raises:
        ValueError: If `parsing_error` is truthy, or if `parsed` is falsy
            (the tool was not invoked).
    """

    # Case 1: the tool call could not be parsed; echo the raw reply and the
    # parse error in the exception.
    parse_failure = tool_output['parsing_error']
    if parse_failure:
        print("Parsing error!")
        raw_text = str(tool_output['raw'].content)
        raise ValueError(
            f"Error parsing your output! Be sure to invoke the tool. Output: {raw_text}. \n Parse error: {parse_failure}"
        )

    # Case 2: nothing was parsed because the tool was never called.
    if not tool_output['parsed']:
        print("Failed to invoke tool!")
        raise ValueError(
            "You did not use the provided tool! Be sure to invoke the tool to structure the output."
        )

    return tool_output



# Chain with output check: prompt -> structured LLM -> validation step that
# raises when the tool call is missing or unparsable.
code_chain_claude_raw = (
    code_gen_prompt_claude
    | structured_llm_claude
    | check_claude_output
)


def insert_errors(inputs):
    """Append a retry instruction carrying the last error to the messages.

    Args:
        inputs: Mapping with `error`, `messages` and `context` keys.

    Returns:
        dict: `messages` (the same list object, extended in place with one
        assistant message) and `context`, ready to feed back into the chain.
    """

    failure = inputs['error']
    history = inputs['messages']

    retry_note = (
        'assistant',
        f"Retry. You are required to fix the parsing errors: {failure} \n\n You must invoke the provided tool."
    )
    # In-place extension: the caller's list sees the added message too.
    history += [retry_note]

    return {'messages': history, 'context': inputs['context']}


# Max re-tries
N = 3
# This will be run as a fallback chain: inject the previous error into the
# messages, then run the checked chain again.
fallback_chain = insert_errors | code_chain_claude_raw
# The raised exception is handed to each fallback under the 'error' key.
code_gen_chain_retry = code_chain_claude_raw.with_fallbacks(
    exception_key='error',
    fallbacks=[fallback_chain for _ in range(N)],
)


def parse_output(solution):
    """Extract the parsed object from an `include_raw=True` result.

    With `include_raw=True` the structured-output step yields a dict holding
    `raw`, `parsed` and `parsing_error`; only `parsed` is returned.
    """
    parsed = solution['parsed']
    return parsed



# Optional: With retry to correct for failure to invoke tool
code_gen_chain = code_gen_chain_retry | parse_output

# No retry
# NOTE: this second assignment rebinds `code_gen_chain`, so the no-retry chain
# is the one used by everything below. Remove or comment out this line to use
# the retry variant defined just above instead.
code_gen_chain = code_gen_prompt_claude | structured_llm_claude | parse_output


# Test
# Same question as the OpenAI smoke test, with the full docs as context.
question = "How do I build a RAG chain in LCEL?"
solution = code_gen_chain.invoke(
    {'context': concatenated_content,
     'messages': [('user', question)]}
)

solution

Code(prefix='To build a RAG (Red-Amber-Green) chain in LCEL, we can use the following approach:', imports='from langchain.chains import RAGChain\nfrom langchain.agents import Tool', code='# Define the tools for the RAG chain\ntools = [\n    Tool(name="tool1", description="This is tool 1", func=lambda x: "green"),\n    Tool(name="tool2", description="This is tool 2", func=lambda x: "amber"),\n    Tool(name="tool3", description="This is tool 3", func=lambda x: "red")\n]\n\n# Create the RAG chain\nrag_chain = RAGChain.from_tools(tools)\n\n# Run the RAG chain\nresult = rag_chain.run("some input")\nprint(result)')

# State

In [7]:
from typing import List
from typing_extensions import TypedDict


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        error: Control-flow flag stored as a string. The check node writes
            'yes' when a test failed and 'no' otherwise; the run in this
            notebook seeds it with "".
        messages: Conversation so far as a list of (role, content) tuples:
            the user question, generated solutions, and test-failure reports.
        generation: Latest code solution. NOTE(review): annotated `str`, but
            the nodes store the object returned by `code_gen_chain` and read
            its `prefix`, `imports` and `code` attributes.
        iterations: Number of generation attempts made so far.
    """

    error: str
    messages: List
    generation: str
    iterations: int

# Graph

In [15]:
### Parameters

# Max trials
# Cap on generation attempts; `decide_to_finish` ends the run once the
# iteration counter reaches it.
max_iterations = 3

# Reflect
# Routing switch read by `decide_to_finish`: with 'reflect' a failed attempt
# goes through the `reflect` node before regenerating; any other value goes
# straight back to `generate`.
flag = 'do not reflect' # or change to 'reflect'


### Nodes

def generate(state: GraphState):
    """
    Produce a candidate code solution from the conversation so far.

    When the previous check failed (`error == 'yes'`), a user message asking
    for another structured attempt is appended before the model is called.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): `generation` (the new solution), `messages` (extended
        with the solution rendered as an assistant message) and `iterations`
        (incremented by one)
    """

    print('---GENERATING CODE SOLUTION---')

    # Unpack the pieces of state this node needs
    messages = state['messages']
    iterations = state['iterations']
    previous_error = state['error']

    # Routed back here after a failed check: nudge the model to retry
    if previous_error == 'yes':
        retry_prompt = (
            'user',
            "Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:",
        )
        messages += [retry_prompt]

    # Ask the chain for a structured solution
    chain_inputs = {'context': concatenated_content, 'messages': messages}
    code_solution = code_gen_chain.invoke(chain_inputs)

    # Record the attempt in the conversation
    rendered = f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}"
    messages += [('assistant', rendered)]

    return {
        'generation': code_solution,
        'messages': messages,
        'iterations': iterations + 1,
    }


def code_check(state: GraphState):
    """
    Check code

    Runs two tests on the current solution: first its import statements on
    their own, then the imports followed by the code block. The first failure
    is appended to the messages as a user message and flagged in the state.

    Each test runs in a fresh namespace that serves as both globals and
    locals, so functions defined by the solution can see the solution's own
    imports, and the solution cannot lean on names defined in this notebook.

    SECURITY: this executes model-generated code with the full privileges of
    the current process. Only run it in an environment you are willing to
    expose to arbitrary code.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): `generation`, `messages` and `iterations` passed
        through, plus `error` set to 'yes' on a failed test and 'no' otherwise
    """

    print('---CHECKING CODE---')

    # State
    messages = state['messages']
    code_solution = state['generation']
    iterations = state['iterations']

    # Get solution components
    imports = code_solution.imports
    code = code_solution.code

    # Check imports
    try:
        exec(imports, {'__name__': '__main__'})
    except Exception as e:
        print('---CODE IMPORT CHECK: FAILED---')
        # Plain assignment: the list does not exist before this point.
        error_message = [('user', f'Your solution failed the import test: {e}')]
        messages += error_message

        return {
            'generation': code_solution,
            'messages': messages,
            'iterations': iterations,
            'error': 'yes',
        }

    # Check execution
    try:
        exec(imports + '\n' + code, {'__name__': '__main__'})
    except Exception as e:
        print('---CODE BLOCK CHECK: FAILED---')
        error_message = [('user', f'Your solution failed the code execution test: {e}')]
        messages += error_message

        return {
            'generation': code_solution,
            'messages': messages,
            'iterations': iterations,
            'error': 'yes',
        }

    # No errors
    print('---NO CODE TEST FAILURES---')
    return {
        'generation': code_solution,
        'messages': messages,
        'iterations': iterations,
        'error': 'no',
    }


def reflect(state: GraphState):
    """
    Reflect on errors

    Calls the code generation chain on the conversation so far (which ends
    with the failure report) and appends its reply to the messages as an
    assistant "reflections" message. The stored solution and the iteration
    counter are passed through unchanged.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): `generation`, `messages` (with the reflection appended)
        and `iterations`
    """

    # Distinct log line so the trace shows when this node ran instead of `generate`.
    print('---REFLECTING ON ERRORS---')

    # State
    messages = state['messages']
    iterations = state['iterations']
    code_solution = state['generation']

    ### Prompt reflection

    # add reflection
    reflections = code_gen_chain.invoke(
        {'context': concatenated_content, 'messages': messages}
    )
    messages += [
        (
            'assistant',
            f"Here are reflections on the error: {reflections}"
        )
    ]

    return {
        'generation': code_solution,
        'messages': messages,
        'iterations': iterations,
    }


## Edges
def decide_to_finish(state: GraphState):
    """
    Determine whether to finish.

    Finishes when the last check reported no error or the attempt budget
    (`max_iterations`) is used up; otherwise routes to another attempt,
    going through reflection first when `flag == 'reflect'`.

    Args:
        state (dict): The current graph state

    Returns:
        str: Next node to call: 'end', 'reflect' or 'generate'
    """

    error = state['error']
    iterations = state['iterations']

    # `>=` rather than `==`: a counter that starts at or above the cap must
    # still terminate the loop instead of retrying forever.
    if error == 'no' or iterations >= max_iterations:
        print('---DECISION: FINISH---')
        return 'end'
    else:
        print('---DECISION: RE-TRY SOLUTION---')
        if flag == 'reflect':
            return 'reflect'
        else:
            return 'generate'

In [16]:
from langgraph.graph import START, END, StateGraph

# Build the state machine over `GraphState`.
workflow = StateGraph(GraphState)

# Nodes: one per step of the generate -> check -> (reflect) loop.
workflow.add_node('generate', generate)  # generate solution
workflow.add_node('check_code', code_check)  # check code
workflow.add_node('reflect', reflect)  # reflect

# Fixed edges.
workflow.add_edge(START, 'generate')
workflow.add_edge('generate', 'check_code')
workflow.add_edge('reflect', 'generate')

# After a check, `decide_to_finish` picks the next hop by returning one of the
# keys of this mapping.
workflow.add_conditional_edges(
    'check_code',
    decide_to_finish,
    {'end': END, 'reflect': 'reflect', 'generate': 'generate'},
)

app = workflow.compile()

In [17]:
question = "How can I directly pass a string to a runnable and use it to construct the input needed for my prompt?"
# Initial state: the question as the only message, no attempts made yet, and
# an empty error flag so the first `generate` call adds no retry prompt.
solution = app.invoke(
    {'messages': [('user', question)], 'iterations': 0, 'error': ""}
)

---GENERATING CODE SOLUTION---
---CHECKING CODE---
---CODE BLOCK CHECK: FAILED---
---DECISION: RE-TRY SOLUTION---
---GENERATING CODE SOLUTION---
---CHECKING CODE---
The given text is: This is the text I want to use.
---NO CODE TEST FAILURES---
---DECISION: FINISH---
