In [7]:
from rich import print
import inspect
from typing import List, Dict, Any
from langgraph.graph import StateGraph
from langchain_core.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate
from langchain_openai import ChatOpenAI
import textwrap
from dotenv import load_dotenv
from langchain_groq import ChatGroq
# Usage


load_dotenv("env")

class NodeFunctionUpdater:
    def __init__(self, llm):
        self.llm = llm

    def update(self, node_name: str, original_function: str, score: float, feedback: str) -> str:
        print(f"Updating node: {node_name}")
        print(f"Original function:\n{original_function}")
        print(f"Score: {score}")
        print(f"Feedback: {feedback}")
        prompt_template = f"""
        You are an expert in optimizing Python functions. I have a function that needs to be improved based on the 
        performance score and feedback provided. Please rewrite the function following the template provided below.
        Original function:
        ```
        {original_function}
        ```
        Performance score: {score}
        Feedback: {feedback}
        
        The improved function should follow this template:

        def {node_name}(state: AgentState) -> AgentState:
            def updated_node() -> ChatPromptTemplate:
                examples = [
                    {{
                        "input": textwrap.dedent('''
                        [HERE GOES THE INPUT 1]
                        '''),
                        "output": textwrap.dedent('''
                        [HERE GOES THE OUTPUT 1]            
                        ''')
                    }},
                    {{
                        "input": textwrap.dedent('''
                        [HERE GOES THE INPUT 2]
                        '''),
                        "output": textwrap.dedent('''
                        [HERE GOES THE OUTPUT 2] 
                        ''')
                    }}
                ]
                
                example_prompt = ChatPromptTemplate.from_messages(
                    [
                        ("human", "{{input}}"),
                        ("ai", "{{output}}"),
                    ]
                )
                
                few_shot_prompt = FewShotChatMessagePromptTemplate(
                    example_prompt=example_prompt,
                    examples=examples,
                )
                system_prompt = textwrap.dedent('''
                [HERE GOES THE SYSTEM PROMPT]
                ''')
                final_prompt = ChatPromptTemplate.from_messages(
                    [
                        ("system", system_prompt),
                        few_shot_prompt,
                        ("human", "{{input}}"),
                    ]
                )
                return final_prompt

            prompt = updated_node()
            chain = prompt | self.llm
            result = chain.invoke({{"input": state['requirements']}})
            return {{"email": result.content}}
        """
        
        prompt = ChatPromptTemplate.from_template(prompt_template)
        chain = prompt | self.llm
        response = chain.invoke({
            "original_function": original_function,
            "score": score,
            "feedback": feedback
        })
        print("LLM response content:")
        print(response.content)
        return response.content

class GraphAgentOptimizer:
    def __init__(self, original_graph: StateGraph, ground_truth: List[Dict[str, Any]], llm):
        self.original_graph = original_graph
        self.ground_truth = ground_truth
        self.llm = llm
        self.node_functions = {}
        self.node_performances = {}
        self.node_updater = NodeFunctionUpdater(llm)

    def load_graph(self):
        print("Loading graph...")
        for node_name, node_func in self.original_graph.nodes.items():
            self.node_functions[node_name] = self.extract_function_source(node_func)
            self.node_performances[node_name] = {'score': 0, 'feedback': ''}
        print(f"Loaded {len(self.node_functions)} nodes.")

    def extract_function_source(self, func):
        if hasattr(func, 'func'):  # For RunnableCallable objects
            return inspect.getsource(func.func)
        elif callable(func):
            return inspect.getsource(func)
        else:
            return str(func)  # Fallback for other types

    def optimize(self, num_iterations: int):
        for iteration in range(num_iterations):
            print(f"Iteration {iteration + 1}")
            for i, example in enumerate(self.ground_truth):
                print(f"  Processing example {i + 1}")
                try:
                    output = self.forward_pass(example['input'])
                    self.evaluate_and_update(output, example['output'])
                except Exception as e:
                    print(f"Error processing example {i + 1}: {str(e)}")
            self.update_node_functions()

    def forward_pass(self, input_data: str) -> Dict[str, Any]:
        print("Performing forward pass...")
        state = {"requirements": input_data}
        for node_name in self.original_graph.nodes:
            print(f"  Executing node: {node_name}")
            node_func = self.get_node_function(node_name)
            state.update(node_func(state))
        print("Forward pass completed.")
        return state

    def evaluate_and_update(self, generated_output: Dict[str, Any], ground_truth_output: str):
        print("Evaluating output...")
        prompt = ChatPromptTemplate.from_template("""
        Compare the generated output with the ground truth output.
        Provide a score from 0 to 1 for each node in the graph,
        where 1 is perfect and 0 is completely wrong.
        Also provide a brief explanation for each score.

        Generated output:
        {generated_output}

        Ground truth output:
        {ground_truth_output}

        Response format:
        [node_name] score: [score]
        [node_name] feedback: [feedback]
        (Repeat for each node)
        """)
        chain = prompt | self.llm
        response = chain.invoke({
            "generated_output": str(generated_output),
            "ground_truth_output": ground_truth_output
        })

        print("Raw LLM response:")
        print(response.content)
        
        lines = response.content.split('\n')
        for node_name in self.node_functions.keys():
            score_line = next((line for line in lines if f"{node_name} score:" in line), None)
            feedback_line = next((line for line in lines if f"{node_name} feedback:" in line), None)
            
            if score_line and feedback_line:
                try:
                    score = float(score_line.split(':')[1].strip())
                except ValueError:
                    print(f"Warning: Could not parse score for {node_name}. Using default score of 0.5")
                    score = 0.5
                
                feedback = feedback_line.split(':', 1)[1].strip() if ':' in feedback_line else ''
                
                self.node_performances[node_name]['score'] = (self.node_performances[node_name]['score'] + score) / 2
                self.node_performances[node_name]['feedback'] += feedback + '\n'
            else:
                print(f"Warning: Could not find score or feedback for {node_name}")

    def update_node_functions(self):
        print("Updating node functions...")
        for node_name, performance in self.node_performances.items():
            if performance['score'] < 0.8:
                print(f"Updating function for node: {node_name}")
                updated_function = self.node_updater.update(
                    node_name,
                    self.node_functions[node_name],
                    performance['score'],
                    performance['feedback']
                )
                self.node_functions[node_name] = updated_function
                print(f"Updated {node_name}")
            else:
                print(f"Node {node_name} performance is satisfactory. No update needed.")

    def get_node_function(self, node_name: str) -> callable:
        function_code = self.node_functions[node_name]
        try:
            exec(function_code, globals())
            return eval(node_name)
        except Exception as e:
            print(f"Error executing function for node {node_name}: {str(e)}")
            print("Function code:")
            print(function_code)
            raise

    def export_optimized_graph(self) -> StateGraph:
        print("Exporting optimized graph...")
        optimized_graph = StateGraph(self.original_graph.state_type)
        for node_name, node_func in self.node_functions.items():
            optimized_graph.add_node(node_name, self.get_node_function(node_name))
        
        for edge in self.original_graph.edges:
            optimized_graph.add_edge(edge.start, edge.end)
        
        for conditional_edge in self.original_graph.conditional_edges:
            optimized_graph.add_conditional_edges(
                conditional_edge.start,
                conditional_edge.condition,
                conditional_edge.edge_map
            )
        
        optimized_graph.set_entry_point(self.original_graph.entry_point)
        
        print("Optimized graph exported.")
        return optimized_graph.compile()



# llm_name = "llama-3.1-8b-instant"
llm_name = "llama-3.1-70b-versatile"
llm = ChatGroq(cache=False, temperature=0.0, model_name=llm_name)
optimizer = GraphAgentOptimizer(graph, ground_truth, llm)
optimizer.load_graph()
optimized_graph = optimizer.optimize(num_iterations=3)

True

In [11]:
import inspect
from typing import Dict, List, Any, Callable
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate, FewShotChatMessagePromptTemplate
from langgraph.graph import StateGraph, END
import textwrap

class NodeFunctionUpdater:
    def __init__(self, llm):
        self.llm = llm

    def update(self, node_name: str, original_function: str, score: float, feedback: str) -> str:
        prompt_template = f"""
        You are an expert in optimizing Python functions. I have a function that needs to be improved based on the 
        performance score and feedback provided. Please rewrite the function following the template provided below.
        Original function:
        ```
        {original_function}
        ```
        Performance score: {score}
        Feedback: {feedback}
        
        The improved function should follow this template:

        def {node_name}(state: Dict[str, Any]) -> Dict[str, Any]:
            def updated_node() -> ChatPromptTemplate:
                examples = [
                    {{
                        "input": textwrap.dedent('''
                        [HERE GOES THE INPUT 1]
                        '''),
                        "output": textwrap.dedent('''
                        [HERE GOES THE OUTPUT 1]            
                        ''')
                    }},
                    {{
                        "input": textwrap.dedent('''
                        [HERE GOES THE INPUT 2]
                        '''),
                        "output": textwrap.dedent('''
                        [HERE GOES THE OUTPUT 2] 
                        ''')
                    }}
                ]
                
                example_prompt = ChatPromptTemplate.from_messages(
                    [
                        ("human", "{{input}}"),
                        ("ai", "{{output}}"),
                    ]
                )
                
                few_shot_prompt = FewShotChatMessagePromptTemplate(
                    example_prompt=example_prompt,
                    examples=examples,
                )
                system_prompt = textwrap.dedent('''
                [HERE GOES THE SYSTEM PROMPT]
                ''')
                final_prompt = ChatPromptTemplate.from_messages(
                    [
                        ("system", system_prompt),
                        few_shot_prompt,
                        ("human", "{{input}}"),
                    ]
                )
                return final_prompt

            prompt = updated_node()
            chain = prompt | self.llm
            result = chain.invoke({{"input": state['requirements']}})
            return {{"{node_name}_output": result.content}}
        """
        
        prompt = ChatPromptTemplate.from_template(prompt_template)
        chain = prompt | self.llm
        response = chain.invoke({
            "original_function": original_function,
            "score": score,
            "feedback": feedback
        })
        return response.content

class GraphAgentOptimizer:
    def __init__(self, original_graph: StateGraph, ground_truth: List[Dict[str, Any]], llm: ChatGroq):
        self.original_graph = original_graph
        self.ground_truth = ground_truth
        self.llm = llm
        self.node_functions = {}
        self.node_performances = {}
        self.node_updater = NodeFunctionUpdater(llm)

    def load_graph(self):
        print("Loading graph...")
        for node_name, node_func in self.original_graph.nodes.items():
            self.node_functions[node_name] = self.extract_function_source(node_func)
            self.node_performances[node_name] = {'score': 0, 'feedback': ''}
        print(f"Loaded {len(self.node_functions)} nodes.")

    def extract_function_source(self, func):
        if hasattr(func, 'func'):  # For RunnableCallable objects
            return inspect.getsource(func.func)
        elif callable(func):
            return inspect.getsource(func)
        else:
            return str(func)  # Fallback for other types

    def optimize(self, num_iterations: int):
        self.load_graph()
        for iteration in range(num_iterations):
            print(f"Iteration {iteration + 1}")
            for i, example in enumerate(self.ground_truth):
                print(f"  Processing example {i + 1}")
                try:
                    output = self.forward_pass(example['input'])
                    self.evaluate_and_update(output, example['output'])
                except Exception as e:
                    print(f"Error processing example {i + 1}: {str(e)}")
            self.update_node_functions()

    def forward_pass(self, input_data: str) -> Dict[str, Any]:
        print("Performing forward pass...")
        state = {"requirements": input_data}
        for node_name in self.original_graph.nodes:
            print(f"  Executing node: {node_name}")
            node_func = self.get_node_function(node_name)
            state.update(node_func(state))
        print("Forward pass completed.")
        return state

    def evaluate_and_update(self, generated_output: Dict[str, Any], ground_truth_output: str):
        print("Evaluating output...")
        prompt = ChatPromptTemplate.from_template("""
        Compare the generated output with the ground truth output.
        Provide a score from 0 to 1 for each node in the graph,
        where 1 is perfect and 0 is completely wrong.
        Also provide a brief explanation for each score.

        Generated output:
        {generated_output}

        Ground truth output:
        {ground_truth_output}

        Response format:
        [node_name] score: [score]
        [node_name] feedback: [feedback]
        (Repeat for each node)
        """)
        chain = prompt | self.llm
        response = chain.invoke({
            "generated_output": str(generated_output),
            "ground_truth_output": ground_truth_output
        })
        
        lines = response.content.split('\n')
        for node_name in self.node_functions.keys():
            score_line = next((line for line in lines if f"{node_name} score:" in line), None)
            feedback_line = next((line for line in lines if f"{node_name} feedback:" in line), None)
            
            if score_line and feedback_line:
                score = float(score_line.split(':')[1].strip())
                feedback = feedback_line.split(':', 1)[1].strip()
                
                self.node_performances[node_name]['score'] = (self.node_performances[node_name]['score'] + score) / 2
                self.node_performances[node_name]['feedback'] += feedback + '\n'
            else:
                print(f"Warning: Could not find score or feedback for {node_name}")

    def update_node_functions(self):
        print("Updating node functions...")
        for node_name, performance in self.node_performances.items():
            if performance['score'] < 0.8:
                print(f"Updating function for node: {node_name}")
                updated_function = self.node_updater.update(
                    node_name,
                    self.node_functions[node_name],
                    performance['score'],
                    performance['feedback']
                )
                self.node_functions[node_name] = updated_function
                print(f"Updated {node_name}")
            else:
                print(f"Node {node_name} performance is satisfactory. No update needed.")

    def get_node_function(self, node_name: str) -> Callable:
        function_code = self.node_functions[node_name]
        try:
            exec(function_code, globals())
            return eval(node_name)
        except Exception as e:
            print(f"Error executing function for node {node_name}: {str(e)}")
            print("Function code:")
            print(function_code)
            raise

    def export_optimized_graph(self) -> StateGraph:
        print("Exporting optimized graph...")
        optimized_graph = StateGraph(self.original_graph.state_type)
        for node_name, node_func in self.node_functions.items():
            optimized_graph.add_node(node_name, self.get_node_function(node_name))
        
        for edge in self.original_graph.edges:
            optimized_graph.add_edge(edge.start, edge.end)
        
        optimized_graph.set_entry_point(self.original_graph.entry_point)
        
        print("Optimized graph exported.")
        return optimized_graph.compile()

# Example usage
def write_email(state: Dict[str, str]) -> Dict[str, str]:
    # Simplified email writing
    email = f"""
    Dear Client,
    Based on your requirements: {state['requirements']}
    We propose the following tasks:
    1. Analyze requirements
    2. Develop solution
    3. Test and deploy
    Best regards,
    AI Team
    """
    return {"email": email}

# Create original graph
original_graph = StateGraph(Dict[str, Any])
original_graph.add_node("write_email", write_email)
original_graph.add_edge("write_email", END)
original_graph.set_entry_point("write_email")

# Ground truth data
ground_truth = [
    {
        "input": "Develop a website with user authentication",
        "output": """
        Dear Client,
        Based on your requirements: Develop a website with user authentication
        We propose the following tasks:
        1. Design user authentication system
        2. Develop frontend and backend
        3. Implement security measures
        4. Test and deploy website
        Best regards,
        AI Team
        """
    },
    {
        "input": "Create a mobile app for task management",
        "output": """
        Dear Client,
        Based on your requirements: Create a mobile app for task management
        We propose the following tasks:
        1. Design app UI/UX
        2. Develop task management features
        3. Implement data synchronization
        4. Test on multiple devices
        5. Deploy to app stores
        Best regards,
        AI Team
        """
    },
    {
        "input": "Set up a data analytics pipeline",
        "output": """
        Dear Client,
        Based on your requirements: Set up a data analytics pipeline
        We propose the following tasks:
        1. Analyze data sources and requirements
        2. Design data pipeline architecture
        3. Implement data collection and processing
        4. Set up analytics and visualization tools
        5. Test and optimize pipeline performance
        Best regards,
        AI Team
        """
    }
]

# Initialize LLM and optimizer
llm = ChatGroq(cache=False, temperature=0.2, model_name="llama-3.1-8b-instant")
optimizer = GraphAgentOptimizer(original_graph, ground_truth, llm)

# Run optimization
optimizer.optimize(num_iterations=3)

# Get optimized graph
optimized_graph = optimizer.export_optimized_graph()

# Test the optimized graph
test_input = "Develop an e-commerce platform with payment integration"
result = optimized_graph.invoke({"requirements": test_input})
print("\nTest Result:")
print(result["email"])

KeyError: 'Input to ChatPromptTemplate is missing variables {\'output\', \'"write_email_output"\', "state[\'requirements\']", \'"input"\', \'"email"\', \'input\', \'\\n                        "input"\'}.  Expected: [\'\\n                        "input"\', \'"email"\', \'"input"\', \'"write_email_output"\', \'input\', \'output\', "state[\'requirements\']"] Received: [\'original_function\', \'score\', \'feedback\']'

In [35]:
from typing import TypedDict
from langgraph.graph import StateGraph, END
from dotenv import load_dotenv
from langchain_groq import ChatGroq

load_dotenv("env")

llm_name = "llama-3.1-70b-versatile"
llm = ChatGroq(cache=False, temperature=0.0, model_name=llm_name)

class AgentState(TypedDict):
    requirements: str
    email: str

def write_email(state: AgentState) -> AgentState:
    # Hardcoded tasks
    tasks = "1. Analyze requirements\n2. Develop solution\n3. Test and deploy"
    
    # Simplified email writing
    email = f"""
    Dear Client,

    Based on your requirements: {state['requirements']}

    We propose the following tasks:
    {tasks}

    Best regards,
    AI Team
    """
    return {"email": email}

# Initialize the graph
graph = StateGraph(AgentState)
graph.add_node("write_email", write_email)
graph.add_edge("write_email", END)
graph.set_entry_point("write_email")

workflow = graph.compile()

# Function to run the workflow
def generate_email(requirements: str) -> str:
    result = workflow.invoke({"requirements": requirements})
    return result["email"]

requirements = "Analyze customer data and provide recommendations"
email = generate_email(requirements)
print(email)

In [44]:
list(graph.nodes.keys())[0]

'write_email'

In [46]:
# get source code of the function
print(inspect.getsource(graph.nodes[list(graph.nodes.keys())[0]].func))

In [31]:
inspect.getsource(graph.nodes["write_email"].func)

OSError: could not get source code

In [13]:
print(optimizer.node_functions['write_email'])

In [27]:
from rich import print
import re
from typing import List, Dict, Any
from langgraph.graph import StateGraph
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq

def escape_curly_braces(text: str) -> str:
    text = str(text)
    return text.replace("{", "{{").replace("}", "}}")


class NodeFunctionUpdater:
    def __init__(self, llm):
        self.llm = llm

    def update(self, node_name: str, original_code: str, score: float, feedback: str) -> str:
        print(f"Updating node: {node_name}")
        print(f"Score: {score}")
        print(f"Feedback: {feedback}")
        prompt_template = f"""
        You are an expert in optimizing Python code. Improve the following script, focusing on the '{node_name}' function.
        The improved function should use an LLM for generating responses.

        Original code:
        ```python
        {original_code}
        ```
        Performance score: {score}
        Feedback: {feedback}

        Please rewrite the entire script, making sure to:
        1. Keep all necessary imports.
        2. Improve the '{node_name}' function to use an LLM for generating responses.
        3. Maintain the overall structure of the script, including the graph setup and workflow compilation.
        4. Ensure that the improved script is fully functional and can be executed as is.

        Return the entire improved Python script, without any additional explanation.
        """

        prompt_template = escape_curly_braces(prompt_template)
        
        prompt = ChatPromptTemplate.from_template(prompt_template)
        chain = prompt | self.llm
        response = chain.invoke({
            "original_code": original_code,
            "score": score,
            "feedback": feedback
        })
        print("LLM response content:")
        print(response.content)
        return response.content

class GraphAgentOptimizer:
    def __init__(self, original_code: str, ground_truth: List[Dict[str, Any]], llm):
        self.original_code = original_code
        self.ground_truth = ground_truth
        self.llm = llm
        self.node_performances = {}
        self.updater = NodeFunctionUpdater(llm)
        self.optimized_code = original_code

    def load_graph(self):
        print("Loading graph...")
        function_pattern = re.compile(r'def\s+(\w+)\s*\(')
        self.node_names = function_pattern.findall(self.original_code)
        for node_name in self.node_names:
            self.node_performances[node_name] = {'score': 0, 'feedback': ''}
        print(f"Loaded {len(self.node_names)} nodes: {', '.join(self.node_names)}")

    def optimize(self, num_iterations: int):
        for iteration in range(num_iterations):
            print(f"\nIteration {iteration + 1}/{num_iterations}")
            for i, example in enumerate(self.ground_truth):
                print(f"  Processing example {i + 1}/{len(self.ground_truth)}")
                try:
                    output = self.forward_pass(example['input'])
                    self.evaluate_and_update(output, example['output'])
                    # print generated output
                    print(f"Generated output: {output}")
                except Exception as e:
                    print(f"Error processing example {i + 1}: {str(e)}")
            self.update_code()

    def forward_pass(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        print("Performing forward pass...")
        locals_dict = {}
        exec(self.optimized_code, globals(), locals_dict)
        workflow = locals_dict['workflow']
        return workflow.invoke(input_data)

    def evaluate_and_update(self, generated_output: Dict[str, Any], ground_truth_output: Dict[str, Any]):
        print("Evaluating and updating node performances...")
        prompt = ChatPromptTemplate.from_template("""
        Compare the generated output with the ground truth output.
        Provide a score from 0 to 1 for each function in the script,
        where 1 is perfect and 0 is completely wrong.
        Also provide a brief explanation for each score.

        Generated output:
        {generated_output}

        Ground truth output:
        {ground_truth_output}

        Functions to evaluate: {node_names}

        Response format:
        [function_name] score: [score]
        [function_name] feedback: [feedback]
        (Repeat for each function)
        """)
        chain = prompt | self.llm
        response = chain.invoke({
            "generated_output": str(generated_output),
            "ground_truth_output": str(ground_truth_output),
            "node_names": ", ".join(self.node_names)
        })

        print("Raw LLM response:")
        print(response.content)
        
        self.parse_llm_response(response.content)

    def parse_llm_response(self, response: str):
        lines = response.split('\n')
        current_node = None
        for line in lines:
            line = line.strip()
            if not line:
                continue
            
            score_match = re.match(r'(\w+)\s+score:\s*([\d.]+)', line, re.IGNORECASE)
            if score_match:
                current_node = score_match.group(1)
                score = float(score_match.group(2))
                if current_node in self.node_performances:
                    self.node_performances[current_node]['score'] = (self.node_performances[current_node]['score'] + score) / 2
                    print(f"Updated score for {current_node}: {score}")
                else:
                    print(f"Warning: Unexpected node {current_node}")
            elif current_node and 'feedback' in line.lower():
                feedback = line.split(':', 1)[1].strip() if ':' in line else line
                if current_node in self.node_performances:
                    self.node_performances[current_node]['feedback'] += feedback + '\n'
                    print(f"Updated feedback for {current_node}: {feedback}")

    def update_code(self):
        print("Updating code...")
        for node_name, performance in self.node_performances.items():
            if performance['score'] < 0.8:
                try:
                    self.optimized_code = self.updater.update(
                        node_name,
                        self.optimized_code,
                        performance['score'],
                        performance['feedback']
                    )
                    print(f"Updated code for {node_name}")
                except Exception as e:
                    print(f"Error updating code for {node_name}: {str(e)}")
            else:
                print(f"Node {node_name} performance is good (score: {performance['score']}). Skipping update.")

    def get_optimized_code(self) -> str:
        return self.optimized_code

# Usage
llm_name = "llama-3.1-8b-instant"
# llm_name = "llama-3.1-70b-versatile"
llm = ChatGroq(cache=False, temperature=0.0, model_name=llm_name)

original_code = """
from typing import TypedDict
from langgraph.graph import StateGraph, END
from dotenv import load_dotenv
from langchain_groq import ChatGroq
load_dotenv("env")
llm_name = "llama-3.1-8b-instant"
llm = ChatGroq(cache=False, temperature=0.0, model_name=llm_name)

class AgentState(TypedDict):
    requirements: str
    email: str

def write_email(state: AgentState) -> AgentState:
    # Hardcoded tasks
    tasks = "1. Analyze requirements\\n2. Develop solution\\n3. Test and deploy"
    
    # Simplified email writing
    email = f'''
    Dear Client,
    Based on your requirements: {state['requirements']}
    We propose the following tasks:
    {{tasks}}
    Best regards,
    AI Team
    '''
    return {"email": email}

# Initialize the graph
graph = StateGraph(AgentState)
graph.add_node("write_email", write_email)
graph.add_edge("write_email", END)
graph.set_entry_point("write_email")
workflow = graph.compile()

# Function to run the workflow
def generate_email(requirements: str) -> str:
    result = workflow.invoke({"requirements": requirements})
    return result["email"]

requirements = "Analyze customer data and provide recommendations"
email = generate_email(requirements)
print(email)
"""

ground_truth = [
    {
        "input": {"requirements": "Develop a mobile app for task management"},
        "output": {
            "email": """
            Dear Client,
            Based on your requirements: Develop a mobile app for task management
            We propose the following tasks:
            1. Analyze user requirements and define app features
            2. Design intuitive UI/UX for task management
            3. Develop core functionality (task creation, editing, deletion)
            4. Implement data synchronization and cloud storage
            5. Conduct thorough testing on multiple devices
            6. Deploy to app stores and gather user feedback
            Best regards,
            AI Team
            """
        }
    },
    {
        "input": {"requirements": "Create an e-commerce website with payment integration"},
        "output": {
            "email": """
            Dear Client,
            Based on your requirements: Create an e-commerce website with payment integration
            We propose the following tasks:
            1. Analyze business requirements and select appropriate e-commerce platform
            2. Design responsive and user-friendly website layout
            3. Set up product catalog and inventory management system
            4. Implement secure payment gateway integration
            5. Develop order processing and shipment tracking features
            6. Conduct thorough security and performance testing
            7. Launch website and provide post-launch support
            Best regards,
            AI Team
            """
        }
    }
]

optimizer = GraphAgentOptimizer(original_code, ground_truth, llm)
optimizer.load_graph()
optimizer.optimize(num_iterations=3)
optimized_code = optimizer.get_optimized_code()

print("Optimized code:")
print(optimized_code)

# Test the optimized code
# exec(optimized_code, globals())
test_input = {"requirements": "Design a customer loyalty program for a retail chain"}
result = generate_email(test_input["requirements"])
print("Optimized graph result:")
print(result)

In [47]:
import re
from typing import List, Dict, Any
from langgraph.graph import StateGraph
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq

def escape_curly_braces(text: str) -> str:
    text = str(text)
    return text.replace("{", "{{").replace("}", "}}")

def replace_function_in_source(original_source: str, old_function_name: str, new_function_code: str) -> str:
    pattern = re.compile(rf"def {old_function_name}\(.*?\):.*?(\n    .*)*?\n", re.DOTALL)
    modified_source = re.sub(pattern, new_function_code, original_source)
    return modified_source

class NodeFunctionUpdater:
    def __init__(self, llm):
        self.llm = llm

    def update(self, node_name: str, original_code: str, score: float, feedback: str) -> str:
        print(f"Updating node: {node_name}")
        print(f"Score: {score}")
        print(f"Feedback: {feedback}")
        prompt_template = f"""
        You are an expert in optimizing Python code. Improve the following function, focusing on the '{node_name}' function.
        The improved function should use an LLM for generating responses.

        Original code:
        ```python
        {original_code}
        ```
        Performance score: {score}
        Feedback: {feedback}

        Please rewrite only the '{node_name}' function, making sure to:
        1. Improve the function to use an LLM for generating responses.
        2. Ensure that the improved function is fully functional and can be executed as is.
        3. Maintain the same function signature and return type.

        Return only the improved Python function, without any additional explanation.
        """

        prompt_template = escape_curly_braces(prompt_template)
        
        prompt = ChatPromptTemplate.from_template(prompt_template)
        chain = prompt | self.llm
        response = chain.invoke({
            "original_code": original_code,
            "score": score,
            "feedback": feedback
        })
        print("LLM response content:")
        print(response.content)
        return response.content

class GraphAgentOptimizer:
    def __init__(self, original_code: str, ground_truth: List[Dict[str, Any]], llm):
        self.original_code = original_code
        self.ground_truth = ground_truth
        self.llm = llm
        self.node_performances = {}
        self.updater = NodeFunctionUpdater(llm)
        self.optimized_code = original_code

    def load_graph(self):
        print("Loading graph...")
        function_pattern = re.compile(r'def\s+(\w+)\s*\(')
        self.node_names = function_pattern.findall(self.original_code)
        for node_name in self.node_names:
            self.node_performances[node_name] = {'score': 0, 'feedback': ''}
        print(f"Loaded {len(self.node_names)} nodes: {', '.join(self.node_names)}")

    def optimize(self, num_iterations: int):
        for iteration in range(num_iterations):
            print(f"\nIteration {iteration + 1}/{num_iterations}")
            for i, example in enumerate(self.ground_truth):
                print(f"  Processing example {i + 1}/{len(self.ground_truth)}")
                try:
                    output = self.forward_pass(example['input'])
                    self.evaluate_and_update(output, example['output'])
                    print(f"Generated output: {output}")
                except Exception as e:
                    print(f"Error processing example {i + 1}: {str(e)}")
            self.update_code()

    def forward_pass(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        print("Performing forward pass...")
        locals_dict = {}
        exec(self.optimized_code, globals(), locals_dict)
        workflow = locals_dict['workflow']
        return workflow.invoke(input_data)

    def evaluate_and_update(self, generated_output: Dict[str, Any], ground_truth_output: Dict[str, Any]):
        print("Evaluating and updating node performances...")
        prompt = ChatPromptTemplate.from_template("""
        Compare the generated output with the ground truth output.
        Provide a score from 0 to 1 for each function in the script,
        where 1 is perfect and 0 is completely wrong.
        Also provide a brief explanation for each score.

        Generated output:
        {generated_output}

        Ground truth output:
        {ground_truth_output}

        Functions to evaluate: {node_names}

        Response format:
        [function_name] score: [score]
        [function_name] feedback: [feedback]
        (Repeat for each function)
        """)
        chain = prompt | self.llm
        response = chain.invoke({
            "generated_output": str(generated_output),
            "ground_truth_output": str(ground_truth_output),
            "node_names": ", ".join(self.node_names)
        })

        print("Raw LLM response:")
        print(response.content)
        
        self.parse_llm_response(response.content)

    def parse_llm_response(self, response: str):
        lines = response.split('\n')
        current_node = None
        for line in lines:
            line = line.strip()
            if not line:
                continue
            
            score_match = re.match(r'(\w+)\s+score:\s*([\d.]+)', line, re.IGNORECASE)
            if score_match:
                current_node = score_match.group(1)
                score = float(score_match.group(2))
                if current_node in self.node_performances:
                    self.node_performances[current_node]['score'] = (self.node_performances[current_node]['score'] + score) / 2
                    print(f"Updated score for {current_node}: {score}")
                else:
                    print(f"Warning: Unexpected node {current_node}")
            elif current_node and 'feedback' in line.lower():
                feedback = line.split(':', 1)[1].strip() if ':' in line else line
                if current_node in self.node_performances:
                    self.node_performances[current_node]['feedback'] += feedback + '\n'
                    print(f"Updated feedback for {current_node}: {feedback}")

    def update_code(self):
        print("Updating code...")
        for node_name, performance in self.node_performances.items():
            if performance['score'] < 0.8:
                try:
                    new_function_code = self.updater.update(
                        node_name,
                        self.optimized_code,
                        performance['score'],
                        performance['feedback']
                    )
                    self.optimized_code = replace_function_in_source(self.optimized_code, node_name, new_function_code)
                    print(f"Updated code for {node_name}")
                except Exception as e:
                    print(f"Error updating code for {node_name}: {str(e)}")
            else:
                print(f"Node {node_name} performance is good (score: {performance['score']}). Skipping update.")

    def get_optimized_code(self) -> str:
        return self.optimized_code

# Usage example
llm_name = "llama-3.1-8b-instant"
llm = ChatGroq(cache=False, temperature=0.0, model_name=llm_name)

original_code = """
from typing import TypedDict
from langgraph.graph import StateGraph, END
from dotenv import load_dotenv
from langchain_groq import ChatGroq
load_dotenv("env")
llm_name = "llama-3.1-8b-instant"
llm = ChatGroq(cache=False, temperature=0.0, model_name=llm_name)

class AgentState(TypedDict):
    requirements: str
    email: str

def write_email(state: AgentState) -> AgentState:
    # Hardcoded tasks
    tasks = "1. Analyze requirements\\n2. Develop solution\\n3. Test and deploy"
    
    # Simplified email writing
    email = f'''
    Dear Client,
    Based on your requirements: {state['requirements']}
    We propose the following tasks:
    {{tasks}}
    Best regards,
    AI Team
    '''
    return {"email": email}

# Initialize the graph
graph = StateGraph(AgentState)
graph.add_node("write_email", write_email)
graph.add_edge("write_email", END)
graph.set_entry_point("write_email")
workflow = graph.compile()

# Function to run the workflow
def generate_email(requirements: str) -> str:
    result = workflow.invoke({"requirements": requirements})
    return result["email"]

requirements = "Analyze customer data and provide recommendations"
email = generate_email(requirements)
print(email)
"""

ground_truth = [
    {
        "input": {"requirements": "Develop a mobile app for task management"},
        "output": {
            "email": """
            Dear Client,
            Based on your requirements: Develop a mobile app for task management
            We propose the following tasks:
            1. Analyze user requirements and define app features
            2. Design intuitive UI/UX for task management
            3. Develop core functionality (task creation, editing, deletion)
            4. Implement data synchronization and cloud storage
            5. Conduct thorough testing on multiple devices
            6. Deploy to app stores and gather user feedback
            Best regards,
            AI Team
            """
        }
    },
    {
        "input": {"requirements": "Create an e-commerce website with payment integration"},
        "output": {
            "email": """
            Dear Client,
            Based on your requirements: Create an e-commerce website with payment integration
            We propose the following tasks:
            1. Analyze business requirements and select appropriate e-commerce platform
            2. Design responsive and user-friendly website layout
            3. Set up product catalog and inventory management system
            4. Implement secure payment gateway integration
            5. Develop order processing and shipment tracking features
            6. Conduct thorough security and performance testing
            7. Launch website and provide post-launch support
            Best regards,
            AI Team
            """
        }
    }
]

optimizer = GraphAgentOptimizer(original_code, ground_truth, llm)
optimizer.load_graph()
optimizer.optimize(num_iterations=3)
optimized_code = optimizer.get_optimized_code()

print("Optimized code:")
print(optimized_code)

# Test the optimized code
exec(optimized_code, globals())
test_input = {"requirements": "Design a customer loyalty program for a retail chain"}
result = generate_email(test_input["requirements"])
print("Optimized graph result:")
print(result)