In [None]:
from langgraph.graph import StateGraph, END
from typing import TypedDict
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from guardrails import Guard
from guardrails.hub import DetectPII, ToxicLanguage

# Chat model shared by the LLM node: every generation attempt, including
# guardrail-triggered retries, is sent through this client.
llm = ChatOpenAI(
    model="gpt-4o-mini",
)

# Notebook 5: Advanced Guardrail Patterns - Retry Logic

## The Problem with Simple Blocking

When output guardrails fail, blocking loses valid responses:
```
User: "Tell me about our CEO"
LLM: "John Smith, email: john@company.com"
Output Guardrail: 🚫 BLOCKED
User sees: "Sorry, can't show that response"
```

The LLM could have answered safely — we just need to ask again and tell it to leave out the problematic content.

---

## The Solution: Automatic Retry with Feedback

1. Detect what failed (PII, toxicity, etc.)
2. Ask the LLM to regenerate without that issue
3. Try up to N times before giving up

**Tradeoffs:**
- ✅ Users still get an answer; the feedback tells the LLM exactly what to leave out (in-context only — nothing is learned permanently)
- ❌ Adds latency and cost (multiple API calls), and success is not guaranteed

In [None]:
# Guardrails. Both use on_fail="exception", so a failed check surfaces as an
# exception that the graph nodes catch and turn into a block/retry decision.

# Input side: flag toxic user messages (threshold=0.5).
input_guard = Guard().use(
    ToxicLanguage(on_fail="exception", threshold=0.5)
)

# Output side: flag responses containing email addresses or phone numbers.
output_guard = Guard().use(
    DetectPII(on_fail="exception", pii_entities=["EMAIL_ADDRESS", "PHONE_NUMBER"])
)

print("‚úÖ Setup complete")

In [None]:
class AgentState(TypedDict):
    """State dict threaded through the guarded agent's graph nodes.

    Nodes read the fields they need and return a copy with their own
    fields updated.
    """

    # Chat history sent to the LLM: prompts plus the LLM's replies.
    messages: list
    # Text of the current user turn.
    user_input: str
    # Latest LLM completion (may still be rejected by the output guard).
    llm_output: str
    # Whether the input guardrail accepted user_input.
    input_safe: bool
    # Whether the output guardrail accepted llm_output.
    output_safe: bool
    # Text that is ultimately shown to the user.
    final_response: str
    # How many output checks have failed so far.
    retry_count: int
    # Corrective instruction sent to the LLM on the next retry.
    feedback: str


print("‚úÖ State defined")

In [None]:
def input_guardrail_node(state: AgentState) -> AgentState:
    """Run the input guardrail on the current user turn.

    Returns a copy of ``state`` with ``input_safe`` set. When the input is
    rejected, ``final_response`` is also filled with a refusal message.

    The check fails closed: any exception raised by ``input_guard`` -- a
    validation failure or an error inside the validator itself -- blocks the
    input. The exception type is logged so those two cases can be told apart.
    """
    user_input = state["user_input"]

    print(f"\nüîç Checking input: {user_input}")

    try:
        input_guard.validate(user_input)
    except Exception as e:  # fail closed: never forward unchecked input
        print(f"üö´ Input blocked by toxicity guardrail ({type(e).__name__})")
        return {
            **state,
            "input_safe": False,
            "final_response": "I cannot process that request due to inappropriate content.",
        }

    print("‚úÖ Input passed guardrails")
    return {**state, "input_safe": True}

def llm_node(state: AgentState) -> AgentState:
    """Call the LLM for the current attempt and record its reply.

    On the first attempt (``retry_count == 0``) the user's input is sent.
    On later attempts the stored ``feedback`` is sent instead, after the
    previous (rejected) reply that is already in ``messages``.

    Returns a copy of ``state`` with ``llm_output`` set and ``messages``
    extended by the new prompt and the LLM response. The incoming
    ``messages`` list is not modified. If the input was flagged unsafe,
    ``state`` is returned unchanged and the LLM is not called.
    """
    if not state["input_safe"]:
        return state

    retry_count = state.get("retry_count", 0)
    print(f"\nü§ñ Generating LLM response (attempt {retry_count + 1})...")

    if retry_count == 0:
        prompt = state["user_input"]
    else:
        # `or` rather than a .get() default so an empty feedback string
        # also falls back to the generic retry instruction.
        prompt = state.get("feedback") or "Please try again without sensitive information."

    # Build a new list instead of appending to state["messages"]: node
    # inputs are treated as read-only.
    messages = [*(state.get("messages") or []), HumanMessage(content=prompt)]

    response = llm.invoke(messages)
    llm_output = response.content

    print(f"LLM said:\n{'-'*50}\n{llm_output}\n{'-'*50}")

    return {
        **state,
        "llm_output": llm_output,
        "messages": [*messages, response],
    }

def output_guardrail_node(state: AgentState) -> AgentState:
    """Validate the latest LLM output with the PII guardrail.

    On success, ``output_safe`` is set and ``llm_output`` becomes the
    ``final_response``. On failure, ``retry_count`` is incremented and a
    corrective ``feedback`` message is stored for the next LLM attempt.
    If the input was flagged unsafe, ``state`` is returned unchanged.

    Fails closed: any exception from ``output_guard`` rejects the output.
    Only the exception type is logged, not its message, to avoid echoing
    flagged text into the logs.
    """
    if not state["input_safe"]:
        return state

    llm_output = state["llm_output"]
    retry_count = state.get("retry_count", 0)

    print(f"\nüîç Checking output (attempt {retry_count + 1})...")

    try:
        output_guard.validate(llm_output)
    except Exception as e:  # fail closed: never show unchecked output
        print(f"üö´ Output blocked by PII guardrail ({type(e).__name__})")

        # Instruction sent to the LLM on the next attempt.
        feedback = (
            "Your previous response contained PII (personal information like email addresses or phone numbers). "
            "Please rewrite your response without including any personal contact information, "
            "email addresses, or phone numbers."
        )

        return {
            **state,
            "output_safe": False,
            "retry_count": retry_count + 1,
            "feedback": feedback,
        }

    print("‚úÖ Output passed guardrails")
    return {
        **state,
        "output_safe": True,
        "final_response": llm_output,
    }

def should_retry(state: AgentState) -> str:
    """Choose the route taken after the output check.

    Returns ``"end"`` when the input was rejected or the output passed,
    ``"retry"`` while fewer than three output checks have failed, and
    ``"max_retries"`` once three have failed (i.e. at most 3 attempts).
    """
    # Either the request was refused up front or the answer is already safe.
    finished = (not state["input_safe"]) or state["output_safe"]
    if finished:
        return "end"

    failures = state.get("retry_count", 0)
    if failures < 3:
        print(f"üîÑ Retrying (attempt {failures + 1}/3)...")
        return "retry"

    print(f"‚ö†Ô∏è Max retries ({failures}) reached")
    return "max_retries"


print("‚úÖ Nodes defined")

In [None]:
# Create the graph
def _build_agent_graph():
    """Return the uncompiled guarded-agent graph.

    Flow: input_check -> llm -> output_check, after which ``should_retry``
    decides whether to loop back to ``llm`` or stop.
    """
    graph = StateGraph(AgentState)

    # Register the three nodes.
    for name, node in (
        ("input_check", input_guardrail_node),
        ("llm", llm_node),
        ("output_check", output_guardrail_node),
    ):
        graph.add_node(name, node)

    # Straight-line part of the flow.
    graph.set_entry_point("input_check")
    graph.add_edge("input_check", "llm")
    graph.add_edge("llm", "output_check")

    # Retry loop: the string returned by should_retry picks the next hop.
    graph.add_conditional_edges(
        "output_check",
        should_retry,
        {
            "retry": "llm",      # another attempt, with feedback
            "max_retries": END,  # give up
            "end": END,          # success, or the input was rejected
        },
    )
    return graph


workflow = _build_agent_graph()
app = workflow.compile()

print("‚úÖ LangGraph agent compiled with retry logic!")

In [None]:
def run_agent(user_input: str):
    """Run one user turn through the guarded agent graph and print the result.

    Args:
        user_input: The user's message.

    Returns:
        The final graph state. If the output guardrail rejected every
        attempt, ``final_response`` is replaced with a fallback that asks
        the user to rephrase.
    """
    print(f"\n{'='*70}")
    print(f"USER: {user_input}")
    print(f"{'='*70}")

    result = app.invoke({
        "user_input": user_input,
        "messages": [],
        "input_safe": False,
        "output_safe": False,
        "llm_output": "",
        "final_response": "",
        "retry_count": 0,
        "feedback": "",
    })

    # The graph only finishes with safe input but unsafe output when the
    # retry limit was hit, so derive "gave up" from the state instead of
    # duplicating the limit that lives in should_retry.
    if result.get("input_safe", False) and not result.get("output_safe", False):
        result["final_response"] = (
            "I couldn't generate a safe response after multiple attempts. "
            "Please rephrase your question to avoid requesting sensitive information."
        )

    print(f"\n{'='*70}")
    print(f"AGENT: {result['final_response']}")
    print(f"{'='*70}\n")

    return result


# Test cases
print("\n" + "="*70)
print("TESTING THE AGENT")
print("="*70)

In [None]:
# Test 1: benign factual question -- expected to pass both guardrails.
print("\nüìù Test 1: Normal query")
run_agent(user_input="What's the capital of France?")

In [None]:
# Test 2: invites contact details, so the PII output guard may reject
# the first answer(s) and trigger the retry loop.
print("\nüìù Test 2: Query that might generate PII")
run_agent(user_input="Create a sample customer record for John Smith with all his contact details")

In [None]:
# Test 3: insulting input -- expected to be stopped by the input guard
# before the LLM is ever called.
print("\nüìù Test 3: Toxic input")
run_agent(user_input="You are stupid and useless")