# Module 02: State Management - Practice Notebook

**Level:** Intermediate  
**Duration:** 4-5 hours  
**Updated:** December 2025 - Advanced state patterns

## Learning Objectives

Master state management:
- ✅ Design robust state schemas
- ✅ Create custom reducers
- ✅ Handle complex state updates
- ✅ Implement state validation
- ✅ Optimize state performance


In [None]:
# Setup
%pip install -q -U langgraph langchain python-dotenv

from typing import TypedDict, Annotated, Optional
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages

# Confirm the setup cell ran (emoji restored from a mis-decoded UTF-8 sequence).
print('✅ State management tools ready!')


---

## Exercise 1: Custom Reducers 🎯

**Objective:** Build custom reducers for different merge strategies.

### Task
Create reducers for:
1. Max value selector
2. Order-preserving list deduplication
3. Dictionary deep merge


In [None]:
# Exercise 1: Custom reducers

# Reducer 1: Keep maximum value
def keep_max(existing: float, new: float) -> float:
    """Reducer that retains whichever of the two values is larger.

    Ties keep *existing*, matching the built-in ``max``.
    """
    return new if new > existing else existing

# Reducer 2: Deduplicate items
def dedupe_list(existing: list, new: list) -> list:
    """Append the items of *new* that are not already present, preserving order.

    Duplicates are filtered both against *existing* and within *new* itself:
    only the first occurrence of each new item is appended. *existing* is
    returned unchanged at the front; it is not re-deduplicated. Items must
    be hashable.

    Args:
        existing: Current list in state.
        new: Items produced by the current update.

    Returns:
        A new list; neither input is modified.
    """
    seen = set(existing)
    fresh = []
    for item in new:
        if item not in seen:
            # Record immediately so a repeat later in `new` is also skipped.
            seen.add(item)
            fresh.append(item)
    return existing + fresh

# Reducer 3: Deep merge dictionaries
def deep_merge(existing: dict, new: dict) -> dict:
    """Return *existing* updated with *new*, merging nested dicts recursively.

    When both sides hold a dict under the same key, those two dicts are
    merged by a recursive call; otherwise the value from *new* wins.
    Neither argument is modified (the top level is copied and every nested
    merge builds its own dict), although the result may share untouched
    nested objects with the inputs.
    """
    merged = existing.copy()
    for key, incoming in new.items():
        current = merged[key] if key in merged else None
        if isinstance(current, dict) and isinstance(incoming, dict):
            incoming = deep_merge(current, incoming)
        merged[key] = incoming
    return merged

class ReducerState(TypedDict):
    """State whose fields each combine updates with a different custom reducer."""
    max_score: Annotated[float, keep_max]  # keeps the larger of old and new score
    tags: Annotated[list, dedupe_list]  # appends only tags not already present
    config: Annotated[dict, deep_merge]  # recursively merges nested dict updates

def node1(state: ReducerState):
    """First writer: seeds a score, two tags, and a nested model config."""
    update = dict(
        max_score=0.7,
        tags=['python', 'ai'],
        config={'model': {'name': 'gpt-4'}},
    )
    return update

def node2(state: ReducerState):
    """Second writer, designed to exercise all three reducers.

    - max_score 0.9 is higher than node1's 0.7, so keep_max keeps it.
    - 'ai' is already present, so dedupe_list only adds 'langchain'.
    - config nests under 'model', so deep_merge combines it with node1's config.
    """
    model_patch = {'temp': 0.7}
    return {
        'max_score': 0.9,
        'tags': ['ai', 'langchain'],
        'config': {'model': model_patch, 'max_tokens': 100},
    }

# Test: wire START -> node1 -> node2 -> END and run once from an empty state.
workflow = StateGraph(ReducerState)
for node_name, node_fn in (('node1', node1), ('node2', node2)):
    workflow.add_node(node_name, node_fn)
for src, dst in ((START, 'node1'), ('node1', 'node2'), ('node2', END)):
    workflow.add_edge(src, dst)

app = workflow.compile()
initial_state = {'max_score': 0.0, 'tags': [], 'config': {}}
result = app.invoke(initial_state)

print(f"Max Score: {result['max_score']}")  # Expected: 0.9
print(f"Tags: {result['tags']}")  # Expected: ['python', 'ai', 'langchain']
print(f"Config: {result['config']}")  # Expected: {'model': {'name': 'gpt-4', 'temp': 0.7}, 'max_tokens': 100}


---

## Exercise 2: State Schema Design 🎯

**Objective:** Design production-ready state schemas.

### Task
Create a schema for a customer support agent with proper organization.


In [None]:
# Exercise 2: Production state schema

from datetime import datetime

class SupportAgentState(TypedDict):
    """Graph state for a customer-support agent, grouped by concern."""
    # Conversation
    messages: Annotated[list, add_messages]  # updates are combined by add_messages
    
    # User info
    user_id: str
    user_tier: str  # 'free', 'premium', 'enterprise'
    
    # Issue tracking
    issue_type: Optional[str]  # 'billing', 'technical', 'general'
    priority: str  # 'low', 'medium', 'high', 'critical'
    # NOTE(review): no node shown here sets sentiment, but check_escalation reads it.
    sentiment: str  # 'positive', 'neutral', 'negative'
    
    # Processing
    needs_escalation: bool  # written by check_escalation
    suggested_solution: Optional[str]
    confidence: float
    
    # Metadata
    session_id: str
    # presumably an ISO-8601 timestamp string (datetime is imported but unused) — TODO confirm
    created_at: str
    resolved: bool

def classify_issue(state: SupportAgentState):
    """Classify the latest message into an issue type and priority.

    Uses simple keyword matching (in production: use an LLM):
    'payment'/'bill' -> billing/high, 'error'/'bug' -> technical/medium,
    anything else -> general/low. An empty or missing conversation is
    classified as general/low instead of raising IndexError.

    Note: matching is by substring, so e.g. 'billion' or 'debug' also match.

    Returns:
        Partial state update with 'issue_type' and 'priority'.
    """
    messages = state.get('messages') or []
    if not messages:
        return {'issue_type': 'general', 'priority': 'low'}

    # NOTE(review): str() guards against non-string message content
    # (e.g. structured content blocks) — confirm the message types used.
    last_message = str(messages[-1].content).lower()

    if any(word in last_message for word in ('payment', 'bill')):
        return {'issue_type': 'billing', 'priority': 'high'}
    if any(word in last_message for word in ('error', 'bug')):
        return {'issue_type': 'technical', 'priority': 'medium'}
    return {'issue_type': 'general', 'priority': 'low'}

def check_escalation(state: SupportAgentState):
    """Decide whether the conversation needs human escalation.

    Escalates when the user is on the 'enterprise' tier, the priority is
    'critical', or the sentiment is 'negative'. Fields that have not been
    populated yet (no node shown sets 'sentiment', for example) simply do
    not trigger escalation, instead of raising KeyError.

    Returns:
        Partial state update with 'needs_escalation'.
    """
    escalate = (
        state.get('user_tier') == 'enterprise'
        or state.get('priority') == 'critical'
        or state.get('sentiment') == 'negative'
    )
    return {'needs_escalation': escalate}

# Inspect the schema: list every declared field with its annotation.
print("Schema fields:")
for field_name, annotation in SupportAgentState.__annotations__.items():
    print(f"  - {field_name}: {annotation}")


---

## Exercise 3: State Validation 🎯

**Objective:** Implement state validation to catch errors early.

### Task
Create a validation node that checks state integrity.


In [None]:
# Exercise 3: State validation

class ValidatedState(TypedDict):
    """Registration-style state checked by validate_state before processing."""
    user_id: str
    email: str
    age: int
    consent: bool
    errors: list  # messages collected by validate_state; empty when valid
    is_valid: bool  # set by validate_state; route_on_validation branches on it

def validate_state(state: ValidatedState):
    """Validate all required fields and report every problem found.

    Checks:
        - user_id is present and non-empty.
        - email has a non-empty local part and domain around its last '@'.
        - age is a number (bools excluded) between 13 and 120 inclusive;
          a missing, None, or non-numeric age is reported, not raised.
        - consent is truthy.

    Returns:
        Partial state update with 'errors' (list of messages) and
        'is_valid' (True exactly when no errors were found).
    """
    errors = []

    # Check required fields (missing, None, and '' all count as absent).
    if not state.get('user_id'):
        errors.append('Missing user_id')

    # Validate email format: `or ''` also covers an explicit None.
    email = str(state.get('email') or '')
    local, at, domain = email.rpartition('@')
    if not (at and local and domain):
        errors.append('Invalid email format')

    # Validate age range; the type check comes first so comparisons cannot raise.
    age = state.get('age')
    is_number = isinstance(age, (int, float)) and not isinstance(age, bool)
    if not is_number or not 13 <= age <= 120:
        errors.append('Age must be between 13 and 120')

    # Check consent
    if not state.get('consent'):
        errors.append('Consent required')

    return {
        'errors': errors,
        'is_valid': not errors,
    }

def process_if_valid(state: ValidatedState):
    """Tag the user_id as processed; return no update for invalid state."""
    if not state['is_valid']:
        return {}
    return {'user_id': f"PROCESSED-{state['user_id']}"}

def route_on_validation(state: ValidatedState) -> str:
    """Pick the next node: 'process' for valid state, otherwise 'error'."""
    if state['is_valid']:
        return 'process'
    return 'error'

# Build workflow: START -> validate -> (process | error) -> END
workflow = StateGraph(ValidatedState)
for node_name, node_fn in (
    ('validate', validate_state),
    ('process', process_if_valid),
    ('error', lambda s: {'user_id': 'ERROR'}),  # replaces user_id with 'ERROR'
):
    workflow.add_node(node_name, node_fn)

workflow.add_edge(START, 'validate')
# route_on_validation returns a branch name that is also the target node's name.
workflow.add_conditional_edges(
    'validate',
    route_on_validation,
    {branch: branch for branch in ('process', 'error')},
)
for terminal in ('process', 'error'):
    workflow.add_edge(terminal, END)

app = workflow.compile()

# Test valid state: every check passes, so the 'process' branch runs.
valid_state = dict(
    user_id='user123',
    email='user@example.com',
    age=25,
    consent=True,
    errors=[],
    is_valid=False,  # overwritten by validate_state
)
result = app.invoke(valid_state)
print(f"Valid test - Errors: {result['errors']}, Valid: {result['is_valid']}")

# Test invalid state: every check fails, so the 'error' branch runs.
invalid_state = dict(
    user_id='',
    email='invalid-email',
    age=10,
    consent=False,
    errors=[],
    is_valid=False,
)
result = app.invoke(invalid_state)
print(f"\nInvalid test - Errors: {result['errors']}, Valid: {result['is_valid']}")


---

## Exercise 4: State Partitioning 🎯

**Objective:** Organize complex state into logical sections.

### Task
Create a well-partitioned state for a multi-agent system.


In [None]:
# Exercise 4: State partitioning

class PartitionedState(TypedDict):
    """Multi-agent state split into four logical partitions (see comments)."""
    # User interaction partition
    conversation: Annotated[list, add_messages]  # updates combined by add_messages
    user_intent: str  # written by interaction_node
    
    # Processing partition
    current_step: str  # name of the last step that ran
    # No reducer: processing_node returns the whole extended list.
    steps_completed: list
    processing_errors: list
    
    # Business logic partition (written by business_logic_node)
    extracted_entities: dict
    computed_score: float
    recommendations: list
    
    # Metadata partition
    session_id: str
    start_time: str
    agent_version: str

def interaction_node(state: PartitionedState):
    """Write the user-interaction partition: a placeholder intent."""
    detected_intent = 'query'  # placeholder; no real intent detection yet
    return dict(user_intent=detected_intent, current_step='interaction')

def processing_node(state: PartitionedState):
    """Write the processing partition: record 'processed_intent' as done.

    steps_completed has no reducer, so the whole extended list is returned
    (a new list; the incoming state is not mutated). A missing or None
    steps_completed is treated as empty instead of raising KeyError.
    """
    completed = state.get('steps_completed') or []
    return {
        'steps_completed': completed + ['processed_intent'],
        'current_step': 'processing',
    }

def business_logic_node(state: PartitionedState):
    """Write the business-logic partition with placeholder results."""
    entities = {'user_name': 'Alice'}
    recommendations = ['Recommendation 1', 'Recommendation 2']
    return dict(
        extracted_entities=entities,
        computed_score=0.85,
        recommendations=recommendations,
    )

# Summarize the partition layout declared in PartitionedState.
_partition_titles = ("User interaction", "Processing status", "Business logic", "Metadata")
print(f"State has {len(_partition_titles)} clear partitions:")
for _number, _title in enumerate(_partition_titles, start=1):
    print(f"{_number}. {_title}")
print("\nThis makes state easier to understand and maintain!")


---

## Exercise 5: Performance Optimization 🎯

**Objective:** Optimize state for large-scale applications.

### Task
Implement state trimming to prevent unbounded growth.


In [None]:
# Exercise 5: State optimization

def trim_messages(existing: list, new: list, max_count: int = 20) -> list:
    """Append *new* to *existing*, keeping at most *max_count* messages.

    When the combined list is too long, the first message (usually the
    system prompt) is kept, followed by the most recent ``max_count - 1``
    messages.

    Args:
        existing: Messages already in state.
        new: Messages produced by the current update.
        max_count: Maximum number of messages to retain; must be >= 1.

    Returns:
        A new list; neither input is modified.

    Raises:
        ValueError: If ``max_count`` is less than 1.
    """
    if max_count < 1:
        raise ValueError(f"max_count must be >= 1, got {max_count}")
    combined = existing + new
    if len(combined) <= max_count:
        return combined
    # Use an explicit start index: the negative-slice form `combined[-(max_count-1):]`
    # turns into `combined[-0:]` (the whole list) when max_count == 1.
    tail_start = len(combined) - (max_count - 1)
    return [combined[0]] + combined[tail_start:]

class OptimizedState(TypedDict):
    """State whose message history is capped by trim_messages."""
    # The lambda adapts trim_messages to a two-argument (existing, new)
    # reducer, fixing max_count at 10.
    messages: Annotated[list, lambda e, n: trim_messages(e, n, 10)]
    turn_count: int  # incremented once per run by add_message_node
    summary: str  # Rolling summary instead of keeping all history

def add_message_node(state: OptimizedState):
    """Emit one new user message and advance the turn counter."""
    next_turn = state["turn_count"] + 1
    return {
        'messages': [{'role': 'user', 'content': f'Message {next_turn}'}],
        'turn_count': next_turn,
    }

def summarize_node(state: OptimizedState):
    """Replace the rolling summary with current message and turn counts.

    In production this would call an LLM to summarize the history.
    """
    message_total = len(state['messages'])
    turns = state['turn_count']
    return {'summary': f"Conversation has {message_total} messages, {turns} turns"}

# Build workflow: START -> add_message -> summarize -> END
workflow = StateGraph(OptimizedState)
workflow.add_node('add_message', add_message_node)
workflow.add_node('summarize', summarize_node)

workflow.add_edge(START, 'add_message')
workflow.add_edge('add_message', 'summarize')
workflow.add_edge('summarize', END)

app = workflow.compile()

# Simulate more turns than the 10-message cap set on OptimizedState.messages,
# so trimming actually kicks in. Each invoke starts from the state returned
# by the previous one.
num_turns = 25
state = {'messages': [{'role': 'system', 'content': 'System prompt'}], 'turn_count': 0, 'summary': ''}

for _ in range(num_turns):
    state = app.invoke(state)

print(f"After {num_turns} turns:")
print(f"  Messages in state: {len(state['messages'])}")  # Should be 10
print(f"  Turn count: {state['turn_count']}")  # Should be 25
print(f"  Summary: {state['summary']}")
print("\n✅ Memory optimized - state stays bounded!")


---

## 📚 Summary

You've mastered:
- ✅ Custom reducers (max, dedupe, deep merge)
- ✅ Production state schema design
- ✅ State validation patterns
- ✅ State partitioning for clarity
- ✅ Performance optimization (trimming)

**Next:** Module 03 - Advanced Control Flow with Command! 🚀
