# Imports

In [None]:
import packages

from context.utils import typer as t

from toolkit.utils import utils
from toolkit.utils.utils import rp_print

from context.infra import clients
import context.instances as inst
import context.consts as const
import context.settings.main as settings_main

from toolkit.llm.langchain.core import integration, utils as utils_lc
from toolkit.llm.langchain.data.persistence import retrievers
from toolkit.llm.langchain.data.indexing import (
  documents, document_loaders, text_splitters,
)
from toolkit.llm.langchain.execution import (
  runnables, graphs, tools as tools_lc, agents
)
from toolkit.llm.langchain.models import (
	prompts as prompts_lc, llms, messages as msgs_lc
)

# Graph

In [2]:
def _assignment_target(source_line: str) -> str:
    """
    Return the variable name assigned at the start of ``source_line``.

    Recognises ``name = ...`` and ``name: Type = ...``. Anything else yields
    ``""``: a bare call such as ``f(key=value)``, a comparison ``a == b``,
    an augmented assignment ``a += b``, or an attribute/subscript target.
    Keyword-argument ``=`` signs are therefore never mistaken for an
    assignment.
    """
    import re

    match = re.match(r"\s*([A-Za-z_][A-Za-z0-9_]*)\s*(?::[^=]+)?=(?!=)", source_line)
    return match.group(1) if match else ""


def create_transfer_tool(description: str, name: str = None):
    """
    Build a no-op ``StructuredTool`` that an agent calls to request a hand-off.

    The tool's function does nothing (it returns ``None``). The request is
    recognised by the tool's name: ``process_tools`` treats names containing
    ``"tool_transfer_to_agent"`` as transfers.

    Args:
        description: Tool description shown to the LLM.
        name: Tool name. When omitted, it is inferred from the variable that
            the result is assigned to on the calling line. For example,
            ``tool_x = create_transfer_tool(...)`` gives ``"tool_x"``. If no
            name can be inferred, ``"transfer_tool"`` is used.
            NOTE: that fallback does not contain "tool_transfer_to_agent".

    Returns:
        A ``tools_lc.StructuredTool`` taking one optional string argument,
        ``query``.
    """
    if name is None:
        frame = utils.inspect.currentframe()
        # currentframe() may return None on interpreters without frame support.
        caller = frame.f_back if frame is not None else None
        try:
            context = (
                utils.inspect.getframeinfo(caller).code_context
                if caller is not None
                else None
            )
        finally:
            # Frames reference their locals; drop them promptly to avoid
            # reference cycles (as recommended by the inspect docs).
            del frame, caller
        # code_context is None when the caller's source is unavailable.
        if context:
            name = _assignment_target("".join(context))

    # Fallback when no usable name was given or inferred.
    if not name:
        name = "transfer_tool"

    schema = t.create_model(
        f"{name}_schema",
        __base__=t.BaseModel,
        query=(str, None),  # optional free-text context for the receiving agent
    )

    return tools_lc.StructuredTool(
        name=name,
        description=description,
        func=lambda query=None: None,  # no-op; the hand-off is detected by tool name
        args_schema=schema,
    )

# llm = llms.create_tooled_llm(
#     inst.llm_main,
#     tools=[tool_transfer_to_agent_mul, tool_transfer_to_agent_add]
# )
# result = llm.invoke("How to add one and two multiply three")
# rp_print(result)


In [3]:
# Lookup table from tool name to executable tool, consulted by process_tools.
# The concrete tools are inserted further down, once they have been defined.
TOOL_REGISTRY: t.Dict[str, tools_lc.BaseTool] = dict()

def process_tools(response: msgs_lc.AIMessage, messages: list) -> tuple[list, bool]:
    """
    Execute the tool calls in an LLM response and collect their results.

    Every tool call receives exactly one ``ToolMessage`` answer. Chat APIs
    such as OpenAI's require each assistant tool call to be answered.
    Calls to unknown tools therefore get an error message instead of being
    silently dropped.

    Args:
        response: The AI message whose ``tool_calls`` (if any) are processed.
        messages: Message history preceding ``response``; it is not mutated.

    Returns:
        tuple[list, bool]:
            - ``messages + [response]`` followed by one ``ToolMessage`` per
              tool call.
            - ``True`` if any call was a transfer request, i.e. its name
              contains ``"tool_transfer_to_agent"``.
    """
    updated_msgs = messages + [response]
    transfer = False

    for tool_call in getattr(response, "tool_calls", None) or []:
        name = tool_call["name"]
        call_id = tool_call["id"]

        if name in TOOL_REGISTRY:
            # Invoking a LangChain tool with the full ToolCall dict returns a
            # ToolMessage that is already bound to call_id. Keep it as-is;
            # str()-ing it would embed "content='...' name=..." text.
            result = TOOL_REGISTRY[name].invoke(tool_call)
            if hasattr(result, "tool_call_id"):
                tool_msg = result
            else:
                tool_msg = msgs_lc.ToolMessage(
                    content=str(result),
                    tool_call_id=call_id,
                )
        elif "tool_transfer_to_agent" in name:
            tool_msg = msgs_lc.ToolMessage(
                content="Transfer request acknowledged",
                tool_call_id=call_id,
            )
            transfer = True
        else:
            # Answer unknown calls so the history stays valid for the next LLM turn.
            tool_msg = msgs_lc.ToolMessage(
                content=f"Error: unknown tool '{name}'",
                tool_call_id=call_id,
            )
        updated_msgs.append(tool_msg)

    return updated_msgs, transfer

# Shared system prompt for every agent, filled in with str.format():
#   {domain}                - the agent's area of expertise
#   {primary_task}          - one-line description of its responsibility
#   {specific_instructions} - agent-specific rules (ADD_SPECIFIC / MUL_SPECIFIC)
# Because str.format() is used, any literal braces added here must be doubled ({{ }}).
SYSTEM_PROMPT_TEMPLATE = """You are an expert agent specializing in {domain}. 

ROLE AND RESPONSIBILITIES:
1. You are responsible for {primary_task}
2. Stay focused on your specific expertise area
3. Collaborate with other agents when needed

TOOL USAGE PRINCIPLES:
1. Use your assigned tools effectively and appropriately
2. Make all necessary tool calls in a single response
3. Validate inputs before using tools
4. Process results clearly and accurately

COLLABORATION RULES:
1. Transfer to other experts when task is outside your expertise
2. Always complete your part before transferring
3. Provide clear context when transferring
4. Never transfer back to an agent that just transferred to you

RESPONSE STRUCTURE:
1. Analyze the task first
2. Use appropriate tools as needed
3. Show work clearly and step by step
4. Transfer only when necessary

{specific_instructions}
"""

def pretty_print_stream(graph_stream):
    """
    Pretty-print a LangGraph stream, showing each message only once.

    Two chunk shapes are handled:

    * ``{node_name: node_update}`` -- regular update chunks.
    * ``(namespace, {node_name: node_update})`` -- chunks produced when
      streaming with ``subgraphs=True``. Chunks with an empty namespace come
      from the parent graph and are skipped.

    Messages are de-duplicated by ``id``. Messages whose ``id`` is ``None``
    (e.g. ToolMessages a node has just built) are always printed.

    Args:
        graph_stream: Iterator from ``graph.stream()``.
    """
    seen_message_ids = set()

    def print_new_messages(messages):
        """Print the messages in ``messages`` that have not been printed yet."""
        # A node may return a single message instead of a list.
        if not isinstance(messages, (list, tuple)):
            messages = [messages]
        for msg in messages:
            msg_id = getattr(msg, "id", None)
            # Only real ids can identify duplicates; a shared None must not
            # hide every id-less message after the first one.
            if msg_id is not None:
                if msg_id in seen_message_ids:
                    continue
                seen_message_ids.add(msg_id)

            is_tool_result = hasattr(msg, "tool_call_id")
            if is_tool_result:
                # Tool output is printed once, labelled as a result.
                print(f"Tool Result: {msg.content}")
            elif getattr(msg, "content", None):
                print(f"Message: {msg.content}")

            for tool_call in getattr(msg, "tool_calls", None) or []:
                print(f"Tool Call: {tool_call.get('name', 'unknown_tool')}")
                print(f"Arguments: {tool_call.get('args', {})}")

            if hasattr(msg, "content") or hasattr(msg, "tool_calls") or is_tool_result:
                print("-" * 50)

    def print_node_updates(updates):
        """Print a header plus the new messages for each node in an update chunk."""
        for node_name, node_update in updates.items():
            print(f"\n=== Update from {node_name} ===")
            # A node update may be None (the node returned nothing) or a non-dict.
            if isinstance(node_update, dict) and "messages" in node_update:
                print_new_messages(node_update["messages"])

    for chunk in graph_stream:
        if isinstance(chunk, tuple):
            # Subgraph streaming: (namespace, {node_name: node_update}).
            ns, updates = chunk
            if not ns:
                continue
            print(f"\n=== Update from subgraph {ns[-1].split(':')[0]} ===")
            print_node_updates(updates)
        else:
            print_node_updates(chunk)


In [4]:
class NODE(t.EnumCustom):
    """Identifiers of the two agent nodes in the graph (values from ``t.auto()``)."""

    AGENT_ADD = t.auto()
    AGENT_MUL = t.auto()
	
@tools_lc.tool
def add(a: int, b: int) -> int:
    """Adds two numbers."""
    # The docstring above is used by @tool as the description sent to the LLM,
    # so it is intentionally kept short and unchanged.
    total = a + b
    return total

@tools_lc.tool
def multiply(a: int, b: int) -> int:
    """Multiplies two numbers."""
    # The docstring above is used by @tool as the description sent to the LLM,
    # so it is intentionally kept short and unchanged.
    product = a * b
    return product

# Make the executable tools reachable by name for process_tools.
TOOL_REGISTRY['add'] = add
TOOL_REGISTRY['multiply'] = multiply

# Hand-off tools. Their names are passed explicitly rather than inferred
# through create_transfer_tool's caller-line introspection. That inference
# depends on the caller's source being available and on how the call is
# formatted. process_tools detects a transfer by the
# "tool_transfer_to_agent" substring, so the names must keep that prefix.
tool_transfer_to_agent_mul = create_transfer_tool(
    description="Ask multiplication agent for help.",
    name="tool_transfer_to_agent_mul",
)

tool_transfer_to_agent_add = create_transfer_tool(
    description="Ask addition agent for help.",
    name="tool_transfer_to_agent_add",
)

# Agent-specific rules, substituted into the {specific_instructions} slot of
# SYSTEM_PROMPT_TEMPLATE by agent_addition (ADD_SPECIFIC) and
# agent_multiplication (MUL_SPECIFIC). These are prompt text sent to the LLM.
ADD_SPECIFIC = """
KEY RULES FOR ADDITION EXPERT:
1. PATTERN RECOGNITION:
	When you see expressions like "(a + b) * c", this ALWAYS requires TWO tool calls:
	- First call: add tool for (a + b)
	- Second call: transfer to multiplication expert for the result * c
	You MUST make BOTH calls in the SAME response.

2. REQUIRED TOOL SEQUENCE:
	For ANY expression involving multiplication after addition:
	Step 1: Use add tool to calculate the addition
	Step 2: IMMEDIATELY use tool_transfer_to_agent_mul in the SAME response
	DO NOT wait for another interaction to transfer.

3. EXAMPLE SEQUENCES:
	For "(3 + 5) * 12":
	- CORRECT (do this):
		1. add(a=3, b=5)
		2. tool_transfer_to_agent_mul(query="8 * 12")
	- INCORRECT (don't do this):
		× Only calling add without transfer
		× Waiting for next message to transfer

4. MANDATORY ACTIONS:
	- NEVER handle addition alone if multiplication follows
	- ALWAYS make both tool calls in one response
	- ALWAYS transfer after completing addition
"""
# NOTE(review): MUL_SPECIFIC only covers "addition needed first". For a
# query like "3*3 + 1", an addition step after the multiplication is not
# addressed here. The template's rule "Never transfer back to an agent that
# just transferred to you" may also block that hand-off; confirm intended.
MUL_SPECIFIC = """
MULTIPLICATION EXPERTISE:
1. You handle multiplication operations using the 'multiply' tool
2. Transfer to addition expert if addition is needed first
3. Complete multiplications when numbers are ready
4. Present final results clearly

Example workflow:
For received "8 * 12":
1. multiply(a=8, b=12)  # Calculate final result
"""
	
def agent_addition(
    state: graphs.MessagesState,
) -> graphs.Command[t.Literal[NODE.AGENT_MUL, graphs.END]]:
    """
    Addition-expert node: one LLM turn with the ``add`` tool plus the
    hand-off tool to the multiplication agent.

    Returns ``graphs.Command(goto=NODE.AGENT_MUL, ...)`` when the model
    requested a transfer. Otherwise it returns a plain ``{"messages": ...}``
    update without a goto.
    """
    llm_with_tools = llms.create_tooled_llm(
        inst.llm_main, [tool_transfer_to_agent_mul, add]
    )
    system_text = SYSTEM_PROMPT_TEMPLATE.format(
        domain="mathematical addition",
        primary_task="handling addition operations and coordinating with multiplication expert",
        specific_instructions=ADD_SPECIFIC,
    )
    history = state["messages"]
    reply: msgs_lc.AIMessage = llm_with_tools.invoke(
        [msgs_lc.SystemMessage(system_text), *history]
    )

    # Run the requested tools; the flag reports whether a hand-off was asked for.
    new_messages, wants_transfer = process_tools(reply, history)
    if not wants_transfer:
        return {"messages": new_messages}
    return graphs.Command(goto=NODE.AGENT_MUL, update={"messages": new_messages})

def agent_multiplication(
    state: graphs.MessagesState,
) -> graphs.Command[t.Literal[NODE.AGENT_ADD, graphs.END]]:
    """
    Multiplication-expert node: one LLM turn with the ``multiply`` tool plus
    the hand-off tool to the addition agent.

    Returns ``graphs.Command(goto=NODE.AGENT_ADD, ...)`` when the model
    requested a transfer. Otherwise it returns a plain ``{"messages": ...}``
    update without a goto.
    """
    tooled_model = llms.create_tooled_llm(
        inst.llm_main, [tool_transfer_to_agent_add, multiply]
    )
    instructions = SYSTEM_PROMPT_TEMPLATE.format(
        domain="mathematical multiplication",
        primary_task="handling multiplication operations and coordinating with addition expert",
        specific_instructions=MUL_SPECIFIC,
    )
    history = state["messages"]
    prompt = [msgs_lc.SystemMessage(instructions), *history]
    answer: msgs_lc.AIMessage = tooled_model.invoke(prompt)

    # Run the requested tools; the flag reports whether a hand-off was asked for.
    messages_out, hand_off = process_tools(answer, history)
    state_update = {"messages": messages_out}
    if hand_off:
        return graphs.Command(goto=NODE.AGENT_ADD, update=state_update)
    return state_update

# Assemble the two-agent graph over the shared MessagesState.
# The only static edge is START -> AGENT_ADD. Moves between agents happen
# through the graphs.Command(goto=...) objects the agent nodes return.
# NOTE(review): the agent nodes have no static outgoing edges. Presumably a
# run therefore stops as soon as an agent returns a plain {"messages": ...}
# update, so tool results are not summarised by a further LLM turn.
# Confirm this is intended.
builder = graphs.StateGraph(graphs.MessagesState)

builder.add_node(NODE.AGENT_ADD, agent_addition)
builder.add_node(NODE.AGENT_MUL, agent_multiplication)

# Every run enters at the addition agent.
builder.add_edge(graphs.START, NODE.AGENT_ADD)

graph = builder.compile()


In [None]:
# Example query. The commented-out alternative needs multiplication before addition.
content = "what's (3 + 5) * 12"
# content = "what's 3*3 + 1"

# Run the graph on the query and pretty-print the streamed output.
pretty_print_stream(
		graph.stream({"messages": [msgs_lc.HumanMessage(content=content)]})
)