# Decorator-based Middleware for HR Agents - LangChain 1.0

**Module:** Decorator-based Middleware Patterns

**What you'll learn:**
- 🎯 `@before_agent` - Before agent starts (once per invocation)
- 🎯 `@before_model` - Before each model call  
- 🎯 `@after_model` - After each model response
- 🎯 `@after_agent` - After agent completes (once per invocation)
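
To make the "once per invocation" versus "each model call" distinction concrete, the sketch below simulates a single invocation in plain Python. It is a toy loop, not LangChain's implementation; `simulate_invocation` and the trace labels are hypothetical names used only for illustration.

```python
def simulate_invocation(model_calls: int) -> list[str]:
    """Return the order in which the four hook types would fire (toy model)."""
    trace = ["before_agent"]                      # fires once, at the start
    for n in range(1, model_calls + 1):
        trace += [
            f"before_model #{n}",                 # fires before every model call
            f"model #{n}",
            f"after_model #{n}",                  # fires after every model response
        ]
    trace.append("after_agent")                   # fires once, at the end
    return trace


# An agent that uses one tool typically calls the model twice:
# once to request the tool, once to write the final answer.
trace = simulate_invocation(model_calls=2)
print(" -> ".join(trace))
```

With two model calls, the session-level hooks each appear once while the call-level hooks each appear twice.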

**Key Fixes:**
- ✅ Uses `middleware` parameter instead of `pre_model_hook`/`post_model_hook`
- ✅ Proper decorator imports from `langchain.agents.middleware`
- ✅ Correct function signatures with `Runtime` parameter

---

## Setup: Install Dependencies

In [None]:
!pip install --pre -U langchain langchain-openai langgraph
!pip install langgraph-checkpoint-sqlite

In [None]:
# Imports
from google.colab import userdata
import os

os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')

# CORRECTED IMPORTS
from langchain.agents import create_agent, AgentState
from langchain.agents.middleware import before_model, after_model, before_agent, after_agent
from langchain_openai import ChatOpenAI
from langchain_core.tools import tool
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.runtime import Runtime
from typing import Annotated, Any
from datetime import datetime
from functools import wraps
import time
import json

print("✅ Setup complete!")

## Setup: HR Data and Tools

In [None]:
# Employee database
EMPLOYEES = {
    "101": {"name": "Priya Sharma", "department": "Engineering", "role": "Senior Developer", "salary": 120000},
    "102": {"name": "Rahul Verma", "department": "Engineering", "role": "Manager", "salary": 180000},
    "103": {"name": "Anjali Patel", "department": "HR", "role": "HR Director", "salary": 200000},
    "104": {"name": "Arjun Reddy", "department": "Sales", "role": "Team Lead", "salary": 150000},
    "105": {"name": "Sneha Gupta", "department": "Marketing", "role": "Specialist", "salary": 110000}
}

@tool
def get_employee_info(employee_id: Annotated[str, "Employee ID"]) -> str:
    """Get employee information."""
    if employee_id in EMPLOYEES:
        emp = EMPLOYEES[employee_id]
        return f"{emp['name']} - {emp['department']} - {emp['role']}"
    return f"Employee {employee_id} not found"

@tool
def check_salary(employee_id: Annotated[str, "Employee ID"]) -> str:
    """Check employee salary. SENSITIVE."""
    if employee_id in EMPLOYEES:
        return f"Salary: ₹{EMPLOYEES[employee_id]['salary']:,}"
    return "Not found"

@tool
def update_employee_record(employee_id: Annotated[str, "Employee ID"], field: Annotated[str, "Field"], value: Annotated[str, "Value"]) -> str:
    """Update employee record."""
    if employee_id in EMPLOYEES and field in EMPLOYEES[employee_id]:
        old_value = EMPLOYEES[employee_id][field]
        EMPLOYEES[employee_id][field] = value
        return f"Updated {field}: {old_value} → {value}"
    return "Update failed"

print(f"✅ Loaded {len(EMPLOYEES)} employees")

---
# Part 1: Session-Level Decorators (CORRECTED)

## @before_agent and @after_agent

**Key Changes:**
- ✅ Import from `langchain.agents.middleware`
- ✅ Add `runtime: Runtime` parameter
- ✅ Return `dict[str, Any] | None` instead of just `dict`
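
The `dict[str, Any] | None` return type can be read as "a partial state update, or nothing". The sketch below imitates that contract with plain dictionaries. It is a simplification under stated assumptions: `apply_hooks`, `add_session`, and `observe_only` are hypothetical helpers, and the real agent merges updates through per-key reducers (for example, `messages` are appended rather than overwritten).

```python
from typing import Any, Callable, Optional

Hook = Callable[[dict[str, Any]], Optional[dict[str, Any]]]


def apply_hooks(state: dict[str, Any], hooks: list[Hook]) -> dict[str, Any]:
    """Run hooks in order; merge dict results into state and skip None."""
    for hook in hooks:
        update = hook(state)
        if update is not None:
            state = {**state, **update}  # shallow merge: later keys win
    return state


def add_session(state: dict[str, Any]) -> Optional[dict[str, Any]]:
    """Return an update only when no session exists yet."""
    if state.get("session_id"):
        return None
    return {"session_id": "session_demo"}


def observe_only(state: dict[str, Any]) -> Optional[dict[str, Any]]:
    """A logging-style hook: reads state, changes nothing."""
    return None


final_state = apply_hooks({"current_user_id": "101"}, [add_session, observe_only])
print(final_state)
```

Returning `None` is the natural choice for hooks that only log or measure, which is why most hooks in this notebook end with `return None`.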

---

## Lab 1.1: @before_agent - Session Initialization

In [None]:
# CORRECTED: Use official decorator from langchain
@before_agent
def initialize_hr_session(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Initialize HR consultation session."""
    if not state.get("session_id"):
        session_id = f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        user_id = state.get("current_user_id", "unknown")
        
        print(f"\n🚀 [@before_agent] initialize_hr_session")
        print(f"   📋 Session ID: {session_id}")
        print(f"   👤 User: {user_id}")
        print(f"   🕐 Started: {datetime.now().isoformat()}")
        
        return {
            "session_id": session_id,
            "session_start": datetime.now().isoformat(),
            "interaction_count": 0
        }
    return None

# can_jump_to declares the early-exit target this hook may request
@before_agent(can_jump_to=["end"])
def verify_user_authentication(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Verify user is authenticated."""
    user_id = state.get("current_user_id")
    
    print(f"\n🚀 [@before_agent] verify_user_authentication")
    
    if not user_id or user_id not in EMPLOYEES:
        print(f"   ❌ Authentication failed for: {user_id}")
        return {
            "authenticated": False,
            "messages": [{"role": "assistant", "content": "❌ Authentication required. Please log in."}],
            "jump_to": "end"  # valid targets are "end", "tools", "model"
        }
    
    print(f"   ✅ Authenticated: {EMPLOYEES[user_id]['name']}")
    return {"authenticated": True}

print("✅ @before_agent decorators defined!")

## Lab 1.2: @after_agent - Session Cleanup

In [None]:
@after_agent
def log_session_summary(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Log session summary after completion."""
    print(f"\n🏁 [@after_agent] log_session_summary")
    
    session_id = state.get("session_id", "unknown")
    start_time = state.get("session_start", "unknown")
    
    if start_time != "unknown":
        duration = (datetime.now() - datetime.fromisoformat(start_time)).total_seconds()
        print(f"   ⏱️  Duration: {duration:.2f}s")
    
    print(f"   📋 Session: {session_id}")
    print(f"   💬 Messages: {len(state.get('messages', []))}")
    print(f"   🕐 Ended: {datetime.now().isoformat()}")
    
    return None

@after_agent
def cleanup_resources(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Release resources after session."""
    print(f"\n🏁 [@after_agent] cleanup_resources")
    print(f"   🧹 Cleaning up session resources...")
    print(f"   💾 Saving session data...")
    print(f"   ✅ Cleanup complete")
    return None

print("✅ @after_agent decorators defined!")

---
# Part 2: Call-Level Decorators (CORRECTED)

## @before_model and @after_model

**Runs:** Before/After EACH model call

---

## Lab 2.1: @before_model - Per-Call Logging

In [None]:
# Global stats
model_call_stats = {
    "total_calls": 0,
    "call_history": []
}

@before_model
def log_model_call(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Log each model call."""
    model_call_stats["total_calls"] += 1
    call_num = model_call_stats["total_calls"]
    
    print(f"\n🤖 [@before_model] log_model_call")
    print(f"   📞 Model Call #{call_num}")
    
    messages = state.get("messages", [])
    if messages:
        last_msg = str(messages[-1].content) if hasattr(messages[-1], 'content') else "No message"
        print(f"   💬 Query: {last_msg[:50]}...")
    
    call_info = {
        "call_number": call_num,
        "timestamp": datetime.now().isoformat()
    }
    model_call_stats["call_history"].append(call_info)
    
    return {"current_model_call": call_num}

@before_model
def check_rate_limit_per_call(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Warn when more than 5 model calls have been made (does not block the call)."""
    if model_call_stats["total_calls"] > 5:
        print(f"\n🤖 [@before_model] check_rate_limit_per_call")
        print(f"   ⚠️  Rate limit warning: {model_call_stats['total_calls']} calls")
    return None

print("✅ @before_model decorators defined!")

## Lab 2.2: @after_model - Response Tracking

In [None]:
response_stats = {
    "total_tokens": 0,
    "total_cost": 0.0
}

@after_model
def track_token_usage(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Track token usage per model call."""
    print(f"\n✅ [@after_model] track_token_usage")
    
    messages = state.get("messages", [])
    
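    # Rough estimate: ~1.3 tokens per whitespace-separated word, summed over the
    # prompt history plus the new response; this is not real usage metadata
    tokens = sum(len(str(m.content).split()) * 1.3 for m in messages if hasattr(m, 'content'))
    tokens = int(tokens)
    
    # Illustrative flat rate of $0.002 per 1K tokens, not actual model pricing
    cost = (tokens / 1000) * 0.002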
    
    response_stats["total_tokens"] += tokens
    response_stats["total_cost"] += cost
    
    call_num = state.get("current_model_call", "?")
    print(f"   💰 Call #{call_num}: ~{tokens} tokens (${cost:.4f})")
    print(f"   📊 Session total: ~{response_stats['total_tokens']} tokens (${response_stats['total_cost']:.4f})")
    
    return None

@after_model
def validate_response_quality(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Validate model response."""
    messages = state.get("messages", [])
    if messages:
        response = str(messages[-1].content) if hasattr(messages[-1], 'content') else ""
        
        if len(response) < 10:
            print(f"\n✅ [@after_model] validate_response_quality")
            print(f"   ⚠️  Warning: Short response ({len(response)} chars)")
    
    return None

print("✅ @after_model decorators defined!")

---
# Part 3: Performance Tracking

Track model call performance and tool execution.
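
Timing a model call comes down to a start/stop pair that shares a timestamp through state. The sketch below shows that pattern in isolation, assuming `time.perf_counter()` as the clock since it is monotonic and intended for measuring intervals. `start_timer` and `stop_timer` are hypothetical names, not LangChain APIs.

```python
import time
from typing import Optional


def start_timer() -> dict[str, float]:
    """'Before' half: return a state update holding the start timestamp."""
    return {"model_call_start": time.perf_counter()}


def stop_timer(state: dict[str, float]) -> Optional[float]:
    """'After' half: return elapsed seconds, or None if no timer was started."""
    start = state.get("model_call_start")
    if start is None:  # explicit None check, so a start value of 0.0 still counts
        return None
    return time.perf_counter() - start


state = start_timer()
time.sleep(0.01)  # stand-in for the model call
elapsed = stop_timer(state)
print(f"Elapsed: {elapsed:.3f}s")
```

`perf_counter()` values are only meaningful relative to each other within one process, so they suit a single invocation but not timestamps that outlive it.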

---

In [None]:
performance_stats = {
    "calls": [],
    "total_time": 0.0
}

tool_execution_stats = {
    "total_calls": 0,
    "by_tool": {},
    "failures": []
}

@before_model
def measure_performance_pre(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Start performance timer."""
    print(f"\n🤖 [@before_model] measure_performance_pre")
    print(f"   ⏱️  Starting timer...")
    return {"model_call_start": time.time()}

@after_model
def measure_performance_post(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """End performance timer."""
    start_time = state.get("model_call_start")
    if start_time:
        duration = time.time() - start_time
        performance_stats["calls"].append(duration)
        performance_stats["total_time"] += duration
        
        avg_time = performance_stats["total_time"] / len(performance_stats["calls"])
        
        print(f"\n✅ [@after_model] measure_performance_post")
        print(f"   ⚡ This call: {duration:.3f}s")
        print(f"   📊 Average: {avg_time:.3f}s")
    
    return None

@after_model
def monitor_tool_execution(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Monitor tool calls."""
    messages = state.get("messages", [])
    
    if messages:
        last_msg = messages[-1]
        if hasattr(last_msg, 'tool_calls') and last_msg.tool_calls:
            print(f"\n✅ [@after_model] monitor_tool_execution")
            for tool_call in last_msg.tool_calls:
                tool_name = tool_call.get('name', 'unknown')
                
                tool_execution_stats["total_calls"] += 1
                tool_execution_stats["by_tool"][tool_name] = tool_execution_stats["by_tool"].get(tool_name, 0) + 1
                
                print(f"   🔧 Tool: {tool_name}")
                print(f"   📞 Total tool calls: {tool_execution_stats['total_calls']}")
    
    return None

print("✅ Performance tracking defined!")

---
# Part 4: Dynamic Prompt Generation

Generate context-aware prompts based on current state.
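
The context string depends only on the clock and the user record, so the logic can be written as pure functions that are easy to test outside the agent. This is a sketch of that idea; `greeting_for_hour` and `build_context` are hypothetical helpers, with the hour thresholds (12 and 17) taken from the cell that follows.

```python
from datetime import datetime


def greeting_for_hour(hour: int) -> str:
    """Map an hour (0-23) to a greeting."""
    if not 0 <= hour <= 23:
        raise ValueError(f"hour must be in 0..23, got {hour}")
    if hour < 12:
        return "Good morning"
    if hour < 17:
        return "Good afternoon"
    return "Good evening"


def build_context(name: str, role: str, department: str, now: datetime) -> str:
    """Combine the greeting with a one-line user description."""
    return f"{greeting_for_hour(now.hour)}! User: {name} ({role}, {department})"


example = build_context(
    "Priya Sharma", "Senior Developer", "Engineering",
    now=datetime(2025, 1, 6, 9, 30),  # fixed time keeps the example deterministic
)
print(example)
```

Passing `now` in as an argument, rather than calling `datetime.now()` inside the function, keeps the output reproducible.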

---

In [None]:
@before_model
def generate_dynamic_context(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    """Build context text (greeting + user) and store it in state.

    Storing `dynamic_context` does not by itself change what the model sees.
    """
    user_id = state.get("current_user_id", "unknown")
    session_time = datetime.now().strftime("%A, %I:%M %p")
    
    print(f"\n🤖 [@before_model] generate_dynamic_context")
    
    # Get user info
    user_context = ""
    if user_id in EMPLOYEES:
        emp = EMPLOYEES[user_id]
        user_context = f"User: {emp['name']} ({emp['role']}, {emp['department']})"
        print(f"   👤 {user_context}")
    
    # Time-based greeting
    hour = datetime.now().hour
    greeting = "Good morning" if hour < 12 else "Good afternoon" if hour < 17 else "Good evening"
    
    print(f"   ⏰ {greeting}, {session_time}")
    
    # user_context already starts with "User:", so it is not prefixed again
    dynamic_instructions = f"{greeting}! {user_context or 'User: Unknown'}"
    
    return {"dynamic_context": dynamic_instructions}

print("✅ Dynamic context generator defined!")

---
# Part 5: Complete Demo - All Decorators Together (FIXED)

**KEY FIX:** Use `middleware` parameter instead of `pre_model_hook`/`post_model_hook`

Let's create an HR agent that uses ALL decorator patterns!

---

In [None]:
# Reset all stats
model_call_stats = {"total_calls": 0, "call_history": []}
response_stats = {"total_tokens": 0, "total_cost": 0.0}
performance_stats = {"calls": [], "total_time": 0.0}
tool_execution_stats = {"total_calls": 0, "by_tool": {}, "failures": []}

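# Custom state schema
# AgentState is a TypedDict, so fields cannot carry default values.
# NotRequired marks keys that hooks fill in later; hooks read them with state.get().
from typing_extensions import NotRequired

class SessionAgentState(AgentState):
    current_user_id: NotRequired[str]
    session_id: NotRequired[str]
    session_start: NotRequired[str]
    interaction_count: NotRequired[int]
    authenticated: NotRequired[bool]
    current_model_call: NotRequired[int]
    model_call_start: NotRequired[float]
    dynamic_context: NotRequired[str]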

# CORRECTED: Create comprehensive HR agent with middleware
comprehensive_hr_agent = create_agent(
    model="openai:gpt-4o-mini",
    tools=[get_employee_info, check_salary, update_employee_record],
    middleware=[  # ✅ USE MIDDLEWARE, NOT pre_model_hook/post_model_hook
        # Session-level
        verify_user_authentication,
        initialize_hr_session,
        # Call-level (before)
        log_model_call,
        check_rate_limit_per_call,
        measure_performance_pre,
        generate_dynamic_context,
        # Call-level (after)
        track_token_usage,
        validate_response_quality,
        measure_performance_post,
        monitor_tool_execution,
        # Session cleanup
        log_session_summary,
        cleanup_resources
    ],
    state_schema=SessionAgentState,
    checkpointer=InMemorySaver(),
    system_prompt="""You are a comprehensive HR assistant with full middleware monitoring.
    
    Help employees with their HR needs while maintaining:
    - Session tracking
    - Performance monitoring  
    - Cost tracking
    - Tool execution logging
    
    Be professional and use the dynamic context provided."""
)

print("✅ Comprehensive HR Agent created!")
print("\nActive middleware:")
print("  🚀 @before_agent: Session init, Authentication")
print("  🤖 @before_model: Logging, Rate limit, Performance, Dynamic context")
print("  ✅ @after_model: Token tracking, Validation, Performance, Tool monitoring")
print("  🏁 @after_agent: Session summary, Cleanup")

## Test the Complete System

In [None]:
config = {"configurable": {"thread_id": "decorator_demo_1"}}

print("=" * 70)
print("COMPREHENSIVE DECORATOR DEMO")
print("=" * 70)

# Test 1: Basic query
print("\n" + "="*70)
print("Test 1: Employee Info Query")
print("="*70)

result = comprehensive_hr_agent.invoke({
    "messages": [{"role": "user", "content": "Tell me about my role and department"}],
    "current_user_id": "101"
}, config)

print(f"\n🤖 Response: {result['messages'][-1].content[:150]}...")

# Test 2: Salary check
print("\n" + "="*70)
print("Test 2: Salary Check")
print("="*70)

result = comprehensive_hr_agent.invoke({
    "messages": [{"role": "user", "content": "What is my current salary?"}],
    "current_user_id": "101"
}, config)

print(f"\n🤖 Response: {result['messages'][-1].content}")

# Final stats
print("\n" + "="*70)
print("SESSION END - FINAL STATISTICS")
print("="*70)

print(f"\n📊 Model Call Statistics:")
print(f"   Total calls: {model_call_stats['total_calls']}")

print(f"\n💰 Token & Cost Statistics:")
print(f"   Total tokens: ~{response_stats['total_tokens']}")
print(f"   Total cost: ${response_stats['total_cost']:.4f}")

print(f"\n⏱️  Performance Statistics:")
if performance_stats['calls']:
    print(f"   Total time: {performance_stats['total_time']:.3f}s")
    print(f"   Average: {performance_stats['total_time']/len(performance_stats['calls']):.3f}s per call")
    print(f"   Min: {min(performance_stats['calls']):.3f}s")
    print(f"   Max: {max(performance_stats['calls']):.3f}s")

print(f"\n🔧 Tool Execution Statistics:")
print(f"   Total tool calls: {tool_execution_stats['total_calls']}")
if tool_execution_stats['by_tool']:
    print(f"   Breakdown:")
    for tool_name, count in tool_execution_stats['by_tool'].items():  # avoid shadowing the @tool decorator
        print(f"      • {tool_name}: {count} times")

print("\n✅ All middleware executed successfully!")

---
# Summary

## Key Fixes Applied

✅ **Correct Imports:**
```python
from langchain.agents.middleware import before_model, after_model, before_agent, after_agent
from langgraph.runtime import Runtime
```

✅ **Correct Function Signatures:**
```python
@before_model
def my_middleware(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
    return None  # or return a dict with state updates
```

✅ **Correct Agent Creation:**
```python
agent = create_agent(
    model="...",
    tools=[...],
    middleware=[...],  # NOT pre_model_hook or post_model_hook
    ...
)
```

## Decorator Types

| Decorator | When it runs | Frequency |
|-----------|--------------|-----------|
| `@before_agent` | Start of invocation | Once per invocation |
| `@before_model` | Before each LLM call | Once per model call |
| `@after_model` | After each model response | Once per model call |
| `@after_agent` | End of invocation | Once per invocation |

---

**Next Steps:** Explore class-based middleware for more complex scenarios!