# Module 4: LangGraph

## Applied AI Scientist Field Notes - Expanded Edition

---


## Module 4: LangGraph - Stateful Workflows and Agent Graphs

### Topics
1. State management
2. Graph-based workflows
3. Conditional routing
4. Retry and error handling
5. Human-in-the-loop (HITL)
6. Production deployment

---

In [None]:
%pip install -q langgraph
%pip install -q typing-extensions

print('LangGraph installed!')

### Section 1: Stateful Agent Graphs

LangGraph enables:
- **State persistence** across steps
- **Conditional routing** based on state
- **Retry logic** for failed steps
- **Human approval** gates
- **Observability** through state inspection

In [None]:
from typing import TypedDict, Annotated, Callable, Dict, Any, List

class AgentState(TypedDict):
    '''Shared state passed between StatefulWorkflow nodes.'''
    messages: List[str]   # conversation history; run() seeds it with the user input
    current_step: str     # name of the node to execute next ('end' stops the run)
    attempts: int         # counter nodes may use for retry/escalation logic
    data: Dict[str, Any]  # free-form scratch space written by nodes

class StatefulWorkflow:
    '''Simplified LangGraph-style workflow engine.

    Nodes are callables that mutate the shared ``AgentState`` in place.
    Outgoing edges of a node are tried in insertion order; the first one whose
    condition is ``None`` or returns truthy for the current state decides the
    next node. A node with no matching edge routes to ``'end'``.
    '''
    
    def __init__(self):
        self.nodes = {}
        self.edges = {}
        self.state = self._fresh_state()
    
    @staticmethod
    def _fresh_state() -> AgentState:
        '''Return a new, empty state positioned at the 'start' node.'''
        return AgentState(messages=[], current_step='start', attempts=0, data={})
    
    def add_node(self, name: str, func: Callable):
        '''Register *func* as the node called *name*.'''
        self.nodes[name] = func
    
    def add_edge(self, from_node: str, to_node: str, condition: Callable = None):
        '''Add an edge; ``condition(state)`` gates it (``None`` = always taken).'''
        if from_node not in self.edges:
            self.edges[from_node] = []
        self.edges[from_node].append({'to': to_node, 'condition': condition})
    
    def run(self, initial_input: str, max_steps=10):
        '''Execute the graph for one input and return the final state.

        Each call starts from a fresh state (step 'start', attempts 0, empty
        data) seeded with ``initial_input``. Execution stops when the current
        step becomes 'end' or after ``max_steps`` iterations.
        '''
        # Reset per run: previously the state kept current_step == 'end' after
        # the first run, so calling run() again executed nothing and appended
        # the new input to the previous conversation.
        self.state = self._fresh_state()
        self.state['messages'].append(initial_input)
        steps = 0
        
        while self.state['current_step'] != 'end' and steps < max_steps:
            current = self.state['current_step']
            
            if current in self.nodes:
                print(f'Step {steps+1}: {current}')
                self.nodes[current](self.state)
            
            # Default to 'end' when no outgoing edge matches.
            next_step = 'end'
            if current in self.edges:
                for edge in self.edges[current]:
                    if edge['condition'] is None or edge['condition'](self.state):
                        next_step = edge['to']
                        break
            
            self.state['current_step'] = next_step
            steps += 1
        
        return self.state

# Example: Customer Support Workflow
def classify(state):
    '''Tag the latest message with an intent: 'refund', 'tracking' or 'general'.'''
    text = state['messages'][-1].lower()
    # The first keyword hit wins, so a message mentioning both is a refund.
    for keyword, intent in (('refund', 'refund'), ('track', 'tracking')):
        if keyword in text:
            break
    else:
        intent = 'general'
    state['data']['intent'] = intent

def handle_refund(state):
    '''Count one refund attempt; flag escalation once more than two were made.'''
    attempt_count = state['attempts'] + 1
    state['attempts'] = attempt_count
    flag = 'escalate' if attempt_count > 2 else 'processed'
    state['data'][flag] = True

def respond(state):
    '''Append the final reply: an escalation notice or a processing confirmation.'''
    was_escalated = state['data'].get('escalate')
    reply = 'Escalated to human' if was_escalated else 'Request processed'
    state['messages'].append(reply)

# Wire the support graph: start -> (refund ->) respond -> end.
wf = StatefulWorkflow()
wf.add_node('start', classify)
wf.add_node('refund', handle_refund)
wf.add_node('respond', respond)
# The two conditions out of 'start' are complementary, so exactly one edge matches.
wf.add_edge('start', 'refund', lambda s: s['data'].get('intent') == 'refund')
wf.add_edge('start', 'respond', lambda s: s['data'].get('intent') != 'refund')
wf.add_edge('refund', 'respond')
wf.add_edge('respond', 'end')

# 'refund' in the input routes through handle_refund (first attempt -> processed),
# so the final message list should end with 'Request processed'.
result = wf.run('I want a refund')
print(f'\nFinal state: {result["messages"]}')

### Production Patterns (Overview)

**Best practices:**
- Max step limits to prevent infinite loops
- State checkpointing for recovery
- Observability through structured logging
- Retry with exponential backoff
- Circuit breakers for external services

### Section 2: Advanced State Management with Checkpointing

Production LangGraph systems need:
- **State persistence**: Save/restore state across runs
- **Checkpointing**: Resume from failures
- **State versioning**: Track state evolution
- **State inspection**: Debug complex workflows

In [None]:
import copy
import hashlib
import json
import os
import pickle
from datetime import datetime
from typing import TypedDict, Dict, Any, List, Optional

class CheckpointableState(TypedDict):
    '''State with checkpointing support'''
    messages: List[Dict[str, str]]  # message dicts, e.g. {'role': ..., 'content': ...}
    current_step: str               # node to execute next ('end' stops the graph)
    attempts: int                   # failure counter used for retry/escalation
    data: Dict[str, Any]            # free-form values written by nodes
    checkpoint_id: str              # id of the most recent checkpoint taken
    created_at: str                 # ISO timestamp set by the caller

class StateManager:
    '''Production state management with persistence.

    Each checkpoint is written to ``<checkpoint_dir>/<id>.json`` and also kept
    in memory in ``state_history`` (oldest first).
    '''
    
    def __init__(self, checkpoint_dir='./checkpoints'):
        self.checkpoint_dir = checkpoint_dir
        self.state_history = []
    
    def _checkpoint_path(self, checkpoint_id: str) -> str:
        '''Return the JSON file path used for *checkpoint_id*.'''
        return os.path.join(self.checkpoint_dir, f'{checkpoint_id}.json')
    
    def create_checkpoint(self, state: CheckpointableState) -> str:
        '''Persist a snapshot of *state* and return its checkpoint id.

        Raises:
            TypeError: if *state* holds values ``json`` cannot serialize.
            OSError: if the checkpoint directory cannot be created or written.
        '''
        timestamp = datetime.utcnow().isoformat()
        # Mixing in the history length keeps ids distinct even when two
        # checkpoints for the same step are taken within one clock tick.
        checkpoint_id = hashlib.md5(
            f"{timestamp}{state['current_step']}{len(self.state_history)}".encode()
        ).hexdigest()[:16]
        
        checkpoint = {
            'id': checkpoint_id,
            'timestamp': timestamp,
            # Deep copy: a shallow copy shares the 'messages' list and 'data'
            # dict with the live state, so later node mutations would silently
            # rewrite every in-memory checkpoint.
            'state': copy.deepcopy(state),
        }
        
        # The directory is not guaranteed to exist; without this the very
        # first checkpoint fails with FileNotFoundError.
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        with open(self._checkpoint_path(checkpoint_id), 'w', encoding='utf-8') as f:
            json.dump(checkpoint, f, indent=2)
        
        self.state_history.append(checkpoint)
        return checkpoint_id
    
    def restore_checkpoint(self, checkpoint_id: str) -> Optional[CheckpointableState]:
        '''Load the state saved under *checkpoint_id*, or ``None`` if no file exists.'''
        try:
            with open(self._checkpoint_path(checkpoint_id), 'r', encoding='utf-8') as f:
                checkpoint = json.load(f)
        except FileNotFoundError:
            return None
        return checkpoint['state']
    
    def get_state_history(self) -> List[dict]:
        '''Return the in-memory list of checkpoints (oldest first).'''
        return self.state_history
    
    def rollback_to_step(self, step_name: str) -> Optional[CheckpointableState]:
        '''Return a copy of the newest checkpointed state whose step is *step_name*.

        Returns ``None`` when no checkpoint for that step exists. A copy is
        returned so that callers mutating it cannot corrupt the history.
        '''
        for checkpoint in reversed(self.state_history):
            if checkpoint['state']['current_step'] == step_name:
                return copy.deepcopy(checkpoint['state'])
        return None

class ProductionGraph:
    '''LangGraph-style executor with checkpointing, error logging and escalation.

    Nodes are callables that mutate the state dict in place. Each node has at
    most one outgoing edge, either fixed (``add_edge``) or conditional
    (``add_conditional_edge``); registering another edge from the same node
    replaces the previous one. A step with no outgoing edge routes to 'end'.
    '''
    
    def __init__(self, enable_checkpointing=True):
        self.nodes = {}
        self.edges = {}
        self.state_manager = StateManager() if enable_checkpointing else None
        # Accumulates across run() calls; read by visualize_execution().
        self.execution_log = []
    
    def add_node(self, name: str, func: Callable):
        '''Register *func* as node *name* (errors it raises are handled by run()).'''
        self.nodes[name] = func
    
    def add_conditional_edge(self, from_node: str, routing_func: Callable):
        '''Route out of *from_node* to whatever ``routing_func(state)`` returns.'''
        self.edges[from_node] = {'type': 'conditional', 'router': routing_func}
    
    def add_edge(self, from_node: str, to_node: str):
        '''Route out of *from_node* to the fixed node *to_node*.'''
        self.edges[from_node] = {'type': 'fixed', 'to': to_node}
    
    def run(self, initial_state: 'CheckpointableState', max_steps=20) -> 'CheckpointableState':
        '''Execute the graph starting at ``initial_state['current_step']``.

        The state dict is mutated in place and returned. A checkpoint is taken
        before every step when checkpointing is enabled. A node that raises is
        logged and retried. Once ``state['attempts']`` reaches 3 the run moves
        to 'human_review'. If no 'human_review' node or edge is registered, the
        run stops there, so callers can tell an escalation from a normal
        completion. Otherwise execution stops at 'end' or after ``max_steps``.
        '''
        state = initial_state
        
        for step_num in range(max_steps):
            current_node = state['current_step']
            
            # Checkpoint before execution
            if self.state_manager:
                checkpoint_id = self.state_manager.create_checkpoint(state)
                state['checkpoint_id'] = checkpoint_id
            
            # Execute node (unregistered steps are passed through to routing)
            try:
                if current_node in self.nodes:
                    print(f'Executing: {current_node}')
                    self.nodes[current_node](state)
                    
                    self.execution_log.append({
                        'step': step_num + 1,
                        'node': current_node,
                        'status': 'success',
                        'timestamp': datetime.utcnow().isoformat()
                    })
                
            except Exception as e:
                print(f'Error in {current_node}: {e}')
                self.execution_log.append({
                    'step': step_num + 1,
                    'node': current_node,
                    'status': 'error',
                    'error': str(e),
                    'timestamp': datetime.utcnow().isoformat()
                })
                
                # Retry the same node, or escalate after too many failures.
                state['attempts'] += 1
                if state['attempts'] >= 3:
                    state['current_step'] = 'human_review'
                    # Without a handler, the next iteration would route
                    # 'human_review' straight to 'end' and the escalated run
                    # would look like a normal completion. Stop here instead.
                    if 'human_review' not in self.nodes and 'human_review' not in self.edges:
                        break
                continue
            
            # Determine next step
            if current_node in self.edges:
                edge = self.edges[current_node]
                if edge['type'] == 'conditional':
                    state['current_step'] = edge['router'](state)
                else:
                    state['current_step'] = edge['to']
            else:
                state['current_step'] = 'end'
            
            # Check termination
            if state['current_step'] == 'end':
                break
        
        return state
    
    def visualize_execution(self) -> str:
        '''Return a human-readable trace of every logged step (all runs so far).'''
        trace = '\nExecution Trace:\n' + '=' * 60 + '\n'
        for entry in self.execution_log:
            status_icon = '✓' if entry['status'] == 'success' else '✗'
            trace += f"{status_icon} Step {entry['step']}: {entry['node']}"
            if entry['status'] == 'error':
                trace += f" - Error: {entry['error']}"
            trace += '\n'
        return trace

# Example: Complex workflow with checkpointing (section banner)
print('PRODUCTION LANGGRAPH WITH CHECKPOINTING', '=' * 90, sep='\n')

def analyze_input(state):
    '''Tag the latest message with a complexity level and a tool-use flag.

    Content longer than 100 characters is 'high' complexity, otherwise 'low'.
    ``requires_tools`` is True when the content mentions "code" (any case).
    '''
    text = state['messages'][-1]['content']
    facts = state['data']
    facts['complexity'] = 'low' if len(text) <= 100 else 'high'
    facts['requires_tools'] = 'code' in text.lower()

def route_by_complexity(state) -> str:
    '''Choose the processing node from ``state['data']['complexity']``.'''
    needs_detail = state['data']['complexity'] == 'high'
    return 'detailed_processing' if needs_detail else 'quick_processing'

def detailed_processing(state):
    '''Mark the request processed at the 'comprehensive' detail level.'''
    state['data'].update(processed=True, detail_level='comprehensive')

def quick_processing(state):
    '''Mark the request processed at the 'basic' detail level.'''
    state['data'].update(processed=True, detail_level='basic')

def generate_response(state):
    '''Append an assistant message naming the detail level used (default 'basic').'''
    level = state['data'].get('detail_level', 'basic')
    reply = {'role': 'assistant', 'content': f'Response generated with {level} detail level'}
    state['messages'].append(reply)

# Build graph
graph = ProductionGraph(enable_checkpointing=True)
graph.add_node('analyze', analyze_input)
graph.add_node('detailed_processing', detailed_processing)
graph.add_node('quick_processing', quick_processing)
graph.add_node('generate', generate_response)

# Route directly out of 'analyze' (the LangGraph conditional-edge pattern).
# The previous wiring sent 'analyze' to a 'routing' step that was never
# registered as a node. Every run spent an extra max_steps iteration and an
# extra checkpoint on it, and the step never appeared in the execution trace.
graph.add_conditional_edge('analyze', route_by_complexity)
graph.add_edge('detailed_processing', 'generate')
graph.add_edge('quick_processing', 'generate')
graph.add_edge('generate', 'end')

# Test with different complexity.
# The same graph instance is reused, so the checkpoint count and the execution
# trace printed below accumulate across both inputs.
for test_input in ['Simple question', 'Complex query about implementing a distributed system with fault tolerance and high availability']:
    print(f'\n--- Testing: {test_input[:50]}... ---')
    
    initial_state = CheckpointableState(
        messages=[{'role': 'user', 'content': test_input}],
        current_step='analyze',
        attempts=0,
        data={},
        checkpoint_id='',
        created_at=datetime.utcnow().isoformat()
    )
    
    final_state = graph.run(initial_state, max_steps=10)
    
    print(f'Complexity: {final_state["data"].get("complexity")}')
    print(f'Detail level: {final_state["data"].get("detail_level")}')
    print(f'Checkpoints created: {len(graph.state_manager.state_history)}')

print('\n' + graph.visualize_execution())
print('=' * 90)

### Section 3: Human-in-the-Loop (HITL) Patterns

Critical for production:
- **Approval gates**: Require human approval for risky actions
- **Review queues**: Batch similar requests for efficient review
- **Escalation logic**: Auto-escalate based on confidence/risk
- **Async approval**: Don't block on human review

In [None]:
from enum import Enum
import queue
import threading
import time

class ApprovalStatus(Enum):
    '''Outcome of a human-in-the-loop approval request.'''
    PENDING = 'pending'
    APPROVED = 'approved'
    REJECTED = 'rejected'
    TIMEOUT = 'timeout'  # not produced by the simulated reviewer below

class HITLRequest:
    '''Human-in-the-loop approval request'''
    def __init__(self, request_id: str, action: str, context: dict, risk_level: str):
        self.request_id = request_id
        self.action = action
        self.context = context
        self.risk_level = risk_level
        self.status = ApprovalStatus.PENDING
        self.reviewer_notes = ''
        self.created_at = datetime.utcnow()

class HITLManager:
    '''Manage human-in-the-loop approvals.

    Low-risk requests at or above ``auto_approve_threshold`` confidence are
    approved immediately. Every other request is recorded in
    ``pending_requests`` and ``approval_queue`` and then decided by a
    simulated reviewer (``_simulate_human_decision``).
    '''
    
    def __init__(self):
        self.pending_requests = {}  # request_id -> HITLRequest; kept after a decision
        self.approval_queue = queue.Queue()  # hand-off for a real review UI; not consumed here
        self.auto_approve_threshold = 0.9  # Confidence threshold
        self._request_seq = 0  # per-manager counter mixed into request ids
    
    def _new_request_id(self, action: str) -> str:
        '''Return a 16-hex-character id for a new request.

        The counter makes the hashed input distinct per request. Previously two
        requests for the same action within one ``time.time()`` tick hashed to
        the same id and the second overwrote the first in ``pending_requests``.
        '''
        self._request_seq += 1
        raw = f"{action}{time.time()}{self._request_seq}"
        return hashlib.md5(raw.encode()).hexdigest()[:16]
    
    def request_approval(self, 
                        action: str, 
                        context: dict, 
                        risk_level: str,
                        confidence: float = 0.5,
                        timeout_seconds: int = 300) -> ApprovalStatus:
        '''Request human approval for *action* and return the decision.

        ``timeout_seconds`` is forwarded to the reviewer hook; the simulated
        reviewer in this demo does not use it.
        '''
        
        # Auto-approve low-risk, high-confidence actions
        if risk_level == 'low' and confidence >= self.auto_approve_threshold:
            print(f'Auto-approved: {action} (confidence: {confidence:.2f})')
            return ApprovalStatus.APPROVED
        
        # Create approval request
        request_id = self._new_request_id(action)
        request = HITLRequest(request_id, action, context, risk_level)
        
        self.pending_requests[request_id] = request
        self.approval_queue.put(request)
        
        print(f'Awaiting human approval for: {action} (risk: {risk_level})')
        
        # In production, this would:
        # 1. Send to review queue/UI
        # 2. Wait for human response via webhook/polling
        # 3. Timeout if no response
        
        # Simulated approval (instant for demo)
        return self._simulate_human_decision(request, timeout_seconds)
    
    def _simulate_human_decision(self, request: HITLRequest, timeout: int) -> ApprovalStatus:
        '''Simulate a reviewer: high risk is approved only if the action says "safe".

        Medium- and low-risk requests are always approved. ``timeout`` is
        ignored by the simulation.
        '''
        if request.risk_level == 'high':
            # High risk requires careful review
            time.sleep(0.1)  # Simulate thinking
            request.status = ApprovalStatus.APPROVED if 'safe' in request.action.lower() else ApprovalStatus.REJECTED
        else:
            request.status = ApprovalStatus.APPROVED
        
        return request.status
    
    def get_pending_reviews(self) -> List[HITLRequest]:
        '''Get all pending approval requests'''
        return [r for r in self.pending_requests.values() if r.status == ApprovalStatus.PENDING]

class HITLGraph:
    '''LangGraph-style graph with human-in-the-loop approval gates.

    ``edges`` maps a node name to the single next node name (use ``add_edge``
    or assign the dict directly). A node with no entry routes to 'end'.
    '''
    
    def __init__(self):
        self.nodes = {}
        self.edges = {}
        self.hitl_manager = HITLManager()
    
    def add_node(self, name: str, func: Callable, requires_approval: bool = False, risk_level: str = 'low'):
        '''Add node *name*; when *requires_approval* is set it is gated by the HITL manager.'''
        self.nodes[name] = {
            'func': func,
            'requires_approval': requires_approval,
            'risk_level': risk_level
        }
    
    def add_edge(self, from_node: str, to_node: str):
        '''Route *from_node* to *to_node*, replacing any existing edge from *from_node*.

        Equivalent to ``self.edges[from_node] = to_node``; mirrors the
        ``add_edge`` API of the other graph classes in this module.
        '''
        self.edges[from_node] = to_node
    
    def run(self, state: dict, max_steps=10):
        '''Execute the graph with HITL gates and return the (mutated) state.

        Stops at 'end', at an unregistered node (left as ``current_step``), when
        an approval is not granted (``current_step`` becomes 'rejected'), or
        after ``max_steps`` iterations.
        '''
        for step_num in range(max_steps):
            current_node = state['current_step']
            
            if current_node == 'end':
                break
            
            if current_node not in self.nodes:
                print(f'Unknown node: {current_node}')
                break
            
            node_config = self.nodes[current_node]
            
            # Check if approval required (asked before the node runs)
            if node_config['requires_approval']:
                approval = self.hitl_manager.request_approval(
                    action=current_node,
                    context=state.get('data', {}),
                    risk_level=node_config['risk_level'],
                    confidence=state.get('confidence', 0.5)
                )
                
                if approval != ApprovalStatus.APPROVED:
                    print(f'Action {current_node} was {approval.value}')
                    state['current_step'] = 'rejected'
                    break
            
            # Execute node
            print(f'Executing: {current_node}')
            node_config['func'](state)
            
            # Move to next step (simplified routing)
            if current_node in self.edges:
                state['current_step'] = self.edges[current_node]
            else:
                state['current_step'] = 'end'
        
        return state

# Example: Data deletion workflow with HITL (section banner)
print('\nHUMAN-IN-THE-LOOP WORKFLOW', '=' * 90, sep='\n')

def validate_request(state):
    '''Mark the request as validated and record a fixed confidence of 0.8.'''
    state['data'].update(validated=True)
    state['confidence'] = 0.8

def execute_deletion(state):
    '''High-risk action: flag the data as deleted (simulated) and announce it.'''
    state['data'].update(deleted=True)
    print('  [CRITICAL] Data deleted')

def send_confirmation(state):
    '''Record that a confirmation was sent (simulated).'''
    state['data'].update(confirmation_sent=True)

hitl_graph = HITLGraph()

# Low-risk node
hitl_graph.add_node('validate', validate_request, requires_approval=False)

# High-risk node requiring approval
hitl_graph.add_node('delete', execute_deletion, requires_approval=True, risk_level='high')

# Low-risk node
hitl_graph.add_node('confirm', send_confirmation, requires_approval=False)

# Linear flow: validate -> delete -> confirm -> end (each key maps to the next node).
hitl_graph.edges = {
    'validate': 'delete',
    'delete': 'confirm',
    'confirm': 'end'
}

# Test workflow
initial_state = {
    'current_step': 'validate',
    'data': {'user_id': '12345'},
    'confidence': 0.8
}

# With the simulated reviewer, the high-risk 'delete' step is rejected (its
# name does not contain 'safe'). The run therefore stops with
# current_step == 'rejected' and the deletion never executes.
final_state = hitl_graph.run(initial_state)

print(f'\nFinal state: {final_state["current_step"]}')
print(f'Data deleted: {final_state["data"].get("deleted", False)}')
print('=' * 90)
print('\nKEY INSIGHT: HITL is essential for high-stakes actions')
print('  - Auto-approve low-risk, high-confidence actions')
print('  - Always require review for irreversible operations')
print('  - Set appropriate timeouts')
print('  - Track approval metrics (approval rate, review time)')

## Interview Questions: LangGraph Production Systems

### For Senior/Staff Engineers

In [None]:
langgraph_interview_questions = [
    {
        'level': 'Senior',
        'question': 'Your LangGraph workflow fails 5% of the time at step 3 (of 7) due to external API timeouts. Users have to restart from the beginning. Design a robust recovery system.',
        'answer': '''
**Problem Analysis:**
- Step 3 failure rate: 5%
- No checkpointing → restart from beginning
- Wasted computation: Steps 1-2 repeated
- Poor UX: User frustration
- Cost: Unnecessary LLM calls

**Solution Architecture:**

**1. Checkpoint-Based Recovery**
```python
import pickle
import hashlib
from datetime import datetime, timedelta

class CheckpointManager:
    def __init__(self, storage_backend='redis'):
        self.storage = Redis() if storage_backend == 'redis' else {}
        self.checkpoint_ttl = 3600  # 1 hour
    
    def save_checkpoint(self, workflow_id: str, step: str, state: dict) -> str:
        '''Save state at each step'''
        checkpoint = {
            'workflow_id': workflow_id,
            'step': step,
            'state': state,
            'timestamp': datetime.utcnow().isoformat(),
        }
        
        key = f'checkpoint:{workflow_id}:{step}'
        
        # Serialize state
        serialized = pickle.dumps(checkpoint)
        
        # Save with TTL
        self.storage.setex(key, self.checkpoint_ttl, serialized)
        
        return key
    
    def load_checkpoint(self, workflow_id: str, step: str = None) -> dict:
        '''Restore state from checkpoint'''
        if step:
            key = f'checkpoint:{workflow_id}:{step}'
            checkpoint = self.storage.get(key)
            if checkpoint:
                return pickle.loads(checkpoint)
        else:
            # Find latest checkpoint
            pattern = f'checkpoint:{workflow_id}:*'
            keys = self.storage.keys(pattern)
            if keys:
                # Get most recent
                latest = max(keys, key=lambda k: self.storage.ttl(k))
                return pickle.loads(self.storage.get(latest))
        
        return None

class ResilientGraph:
    '''LangGraph with automatic recovery'''
    
    def __init__(self):
        self.checkpoint_mgr = CheckpointManager()
        self.nodes = {}
        self.max_retries = 3
    
    def add_node(self, name: str, func: Callable, 
                 idempotent: bool = True, 
                 retry_strategy: str = 'exponential_backoff'):
        '''Add node with retry configuration'''
        self.nodes[name] = {
            'func': func,
            'idempotent': idempotent,
            'retry_strategy': retry_strategy
        }
    
    def execute_node_with_retry(self, node_name: str, state: dict, workflow_id: str):
        '''Execute node with automatic retry and checkpointing'''
        node = self.nodes[node_name]
        
        for attempt in range(self.max_retries):
            try:
                # Save checkpoint before execution
                self.checkpoint_mgr.save_checkpoint(workflow_id, node_name, state)
                
                # Execute
                result = node['func'](state)
                
                # Success - update state
                state.update(result or {})
                
                return state, True
                
            except Exception as e:
                print(f'Attempt {attempt + 1} failed: {e}')
                
                # Last attempt
                if attempt == self.max_retries - 1:
                    return state, False
                
                # Retry with backoff
                if node['retry_strategy'] == 'exponential_backoff':
                    wait_time = 2 ** attempt  # 1s, 2s, 4s
                    time.sleep(wait_time)
                
                # Check if idempotent
                if not node['idempotent']:
                    # Non-idempotent operations need special handling
                    # Check if operation actually succeeded
                    if self._verify_operation(node_name, state):
                        print(f'Operation succeeded despite error (idempotency check)')
                        return state, True
        
        return state, False
    
    def run_with_recovery(self, workflow_id: str, initial_state: dict, resume: bool = False):
        '''Execute workflow with automatic recovery'''
        
        if resume:
            # Attempt to restore from checkpoint
            checkpoint = self.checkpoint_mgr.load_checkpoint(workflow_id)
            if checkpoint:
                print(f"Resuming from checkpoint: {checkpoint['step']}")
                state = checkpoint['state']
                start_step = checkpoint['step']
            else:
                print('No checkpoint found, starting from beginning')
                state = initial_state
                start_step = None
        else:
            state = initial_state
            start_step = None
        
        # Execute workflow
        step_sequence = ['step1', 'step2', 'step3', 'step4', 'step5', 'step6', 'step7']
        
        # Skip to resume point
        if start_step:
            step_sequence = step_sequence[step_sequence.index(start_step):]
        
        for step in step_sequence:
            print(f'\nExecuting: {step}')
            state, success = self.execute_node_with_retry(step, state, workflow_id)
            
            if not success:
                print(f'Step {step} failed after {self.max_retries} retries')
                print(f'Workflow can be resumed with workflow_id: {workflow_id}')
                return {'status': 'failed', 'last_step': step, 'state': state}
        
        # Clean up checkpoints on success
        self._cleanup_checkpoints(workflow_id)
        
        return {'status': 'success', 'state': state}
```

**2. Compensating Transactions (for non-idempotent operations)**
```python
class CompensatingTransaction:
    '''Handle rollback for failed workflows'''
    
    def __init__(self):
        self.compensation_log = []
    
    def record_action(self, step: str, action: dict, compensation: Callable):
        '''Record action with its compensation function'''
        self.compensation_log.append({
            'step': step,
            'action': action,
            'compensation': compensation,
            'timestamp': datetime.utcnow()
        })
    
    def rollback(self, from_step: str = None):
        '''Rollback actions from failed step backwards'''
        print('Starting rollback...')
        
        # Reverse order
        for entry in reversed(self.compensation_log):
            if from_step and entry['step'] == from_step:
                break
            
            try:
                print(f"Compensating: {entry['step']}")
                entry['compensation'](entry['action'])
            except Exception as e:
                print(f"Compensation failed for {entry['step']}: {e}")
                # Log for manual intervention
```

**3. Progress Tracking for UX**
```python
class ProgressTracker:
    '''Track and display workflow progress to users'''
    
    def __init__(self, total_steps: int):
        self.total_steps = total_steps
        self.current_step = 0
        self.step_details = {}
    
    def update(self, step_name: str, status: str, message: str = ''):
        '''Update progress'''
        self.current_step += 1
        self.step_details[step_name] = {
            'status': status,
            'message': message,
            'timestamp': datetime.utcnow().isoformat()
        }
        
        # Send to frontend
        progress_pct = (self.current_step / self.total_steps) * 100
        self._emit_progress_event({
            'step': step_name,
            'progress': progress_pct,
            'status': status,
            'message': message
        })
    
    def _emit_progress_event(self, data: dict):
        '''Emit progress to frontend (WebSocket, SSE, polling)'''
        # In production: send via WebSocket
        print(f"Progress: {data['progress']:.0f}% - {data['step']} - {data['status']}")
```

**4. Complete Solution**
```python
class ProductionWorkflow:
    '''Production workflow with full recovery support'''
    
    def __init__(self):
        self.graph = ResilientGraph()
        self.progress = ProgressTracker(total_steps=7)
        self.compensator = CompensatingTransaction()
    
    def execute(self, workflow_id: str, initial_state: dict, resume: bool = False):
        '''Execute with recovery and progress tracking'''
        
        try:
            # Run workflow
            result = self.graph.run_with_recovery(workflow_id, initial_state, resume)
            
            if result['status'] == 'success':
                self.progress.update('complete', 'success', 'Workflow completed')
                return result
            else:
                # Failed but recoverable
                self.progress.update(result['last_step'], 'failed', 'Step failed, can resume')
                return result
                
        except Exception as e:
            # Unrecoverable error - rollback
            print(f'Unrecoverable error: {e}')
            self.compensator.rollback()
            raise
```

**Results:**
- **Retry success rate**: 5% failure → 0.125% failure (97.5% improvement)
  - With 3 retries and 5% failure rate per attempt: 0.05^3 = 0.000125 = 0.0125%
- **User experience**: Resume from failure point, not restart
- **Cost savings**: No repeated work for steps 1-2
- **Observability**: Progress tracking, checkpoint history

**Production Checklist:**
- [x] Checkpoint every step
- [x] Retry with exponential backoff
- [x] Idempotency checks
- [x] Compensating transactions
- [x] Progress tracking
- [x] TTL for checkpoints (prevent storage bloat)
- [x] Monitoring/alerting on high retry rates
        ''',
    },
    {
        'level': 'Staff',
        'question': 'Design a LangGraph system that can handle 10,000 concurrent workflows, with state persistence, distributed execution, and sub-second routing decisions. Include architectural diagrams and trade-offs.',
        'answer': '''
**High-Scale LangGraph Architecture:**

**1. Distributed State Store**
```python
import redis
from redis.cluster import RedisCluster

class DistributedStateStore:
    '''Scalable state storage with sharding'''
    
    def __init__(self, redis_cluster_nodes: List[dict]):
        # Redis Cluster for horizontal scaling
        self.cluster = RedisCluster(
            startup_nodes=redis_cluster_nodes,
            decode_responses=False,
            skip_full_coverage_check=True
        )
        
        # Local cache for hot states
        self.local_cache = LRUCache(maxsize=1000)
    
    def save_state(self, workflow_id: str, state: dict):
        '''Save state with automatic sharding'''
        # Serialize
        state_bytes = pickle.dumps(state)
        
        # Compress for large states
        if len(state_bytes) > 1024:  # 1KB
            import zlib
            state_bytes = zlib.compress(state_bytes)
        
        # Save to cluster (auto-sharded by workflow_id)
        self.cluster.setex(
            f'state:{workflow_id}',
            3600,  # 1 hour TTL
            state_bytes
        )
        
        # Update local cache
        self.local_cache.put(workflow_id, state)
    
    def load_state(self, workflow_id: str) -> dict:
        '''Load state with cache'''
        # Check local cache first
        cached = self.local_cache.get(workflow_id)
        if cached:
            return cached
        
        # Load from cluster
        state_bytes = self.cluster.get(f'state:{workflow_id}')
        if not state_bytes:
            return None
        
        # Decompress if needed
        try:
            state = pickle.loads(state_bytes)
        except:
            import zlib
            state = pickle.loads(zlib.decompress(state_bytes))
        
        # Update cache
        self.local_cache.put(workflow_id, state)
        
        return state
```

**2. Async Execution Engine**
```python
import asyncio
import aiokafka

class AsyncGraphExecutor:
    '''Async execution for high concurrency'''
    
    def __init__(self, state_store: DistributedStateStore):
        self.state_store = state_store
        self.semaphore = asyncio.Semaphore(100)  # Limit concurrent LLM calls
        self.metrics = defaultdict(int)
    
    async def execute_node(self, workflow_id: str, node_name: str, state: dict):
        '''Execute single node asynchronously'''
        async with self.semaphore:
            start = time.time()
            
            try:
                # Execute node function
                result = await self._run_node_func(node_name, state)
                
                # Update state
                state.update(result)
                
                # Save to distributed store
                await asyncio.to_thread(
                    self.state_store.save_state,
                    workflow_id,
                    state
                )
                
                latency = (time.time() - start) * 1000
                self.metrics['success'] += 1
                self.metrics['total_latency_ms'] += latency
                
                return state, True
                
            except Exception as e:
                self.metrics['errors'] += 1
                return state, False
    
    async def _run_node_func(self, node_name: str, state: dict):
        '''Run node function (may call LLM)'''
        # Simulated async LLM call
        await asyncio.sleep(0.1)
        return {'processed': True}
    
    async def execute_batch(self, workflows: List[tuple]):
        '''Execute multiple workflows concurrently'''
        tasks = [
            self.execute_workflow(workflow_id, state)
            for workflow_id, state in workflows
        ]
        
        results = await asyncio.gather(*tasks, return_exceptions=True)
        
        return results
```

**3. Message Queue for Distributed Execution**
```python
class WorkflowOrchestrator:
    '''Distribute workflow execution across workers'''
    
    def __init__(self, kafka_brokers: List[str]):
        self.producer = aiokafka.AIOKafkaProducer(
            bootstrap_servers=kafka_brokers,
            value_serializer=lambda v: json.dumps(v).encode()
        )
        
        self.consumer = aiokafka.AIOKafkaConsumer(
            'workflow-tasks',
            bootstrap_servers=kafka_brokers,
            group_id='graph-workers',
            value_deserializer=lambda v: json.loads(v.decode())
        )
    
    async def submit_workflow(self, workflow_id: str, initial_state: dict):
        '''Submit workflow for distributed execution'''
        message = {
            'workflow_id': workflow_id,
            'state': initial_state,
            'current_step': 'start',
            'submitted_at': datetime.utcnow().isoformat()
        }
        
        await self.producer.send('workflow-tasks', value=message)
    
    async def worker_loop(self, executor: AsyncGraphExecutor):
        '''Worker that consumes and executes workflows'''
        async for message in self.consumer:
            workflow_data = message.value
            
            # Execute current step
            state, success = await executor.execute_node(
                workflow_data['workflow_id'],
                workflow_data['current_step'],
                workflow_data['state']
            )
            
            if success:
                # Determine next step (routing logic)
                next_step = self._route_next_step(state)
                
                if next_step != 'end':
                    # Submit next step to queue
                    await self.submit_workflow(
                        workflow_data['workflow_id'],
                        state
                    )
```

**4. Fast Routing with Decision Trees**
```python
class OptimizedRouter:
    '''Sub-second routing decisions'''
    
    def __init__(self):
        # Pre-compiled routing logic
        self.routing_tree = self._build_routing_tree()
        self.cache = LRUCache(maxsize=10000)
    
    def _build_routing_tree(self):
        '''Build decision tree for fast routing'''
        # Routing rules compiled into efficient structure
        return {
            'complexity': {
                'high': {'confidence': {'high': 'generate', 'low': 'review'}},
                'low': 'generate'
            }
        }
    
    def route(self, state: dict) -> str:
        '''Route in < 1ms'''
        # Check cache first
        cache_key = self._state_fingerprint(state)
        cached = self.cache.get(cache_key)
        if cached:
            return cached
        
        # Traverse decision tree
        node = self.routing_tree
        
        # Fast lookup
        complexity = state.get('complexity', 'low')
        if complexity in node:
            node = node[complexity]
            if isinstance(node, dict):
                confidence = state.get('confidence', 'low')
                next_step = node.get('confidence', {}).get(confidence, 'generate')
            else:
                next_step = node
        else:
            next_step = 'generate'
        
        # Cache result
        self.cache.put(cache_key, next_step)
        
        return next_step
    
    def _state_fingerprint(self, state: dict) -> str:
        '''Fast state hashing for cache keys'''
        # Hash only routing-relevant fields
        key_fields = f"{state.get('complexity')}_{state.get('confidence')}"
        return key_fields
```

**5. Complete Architecture**
```
┌─────────────────────────────────────────────────────────────┐
│                        Load Balancer                         │
│                      (10K req/sec)                          │
└────────────────────┬────────────────────────────────────────┘
                     │
        ┌────────────┴────────────┐
        │                         │
   ┌────▼─────┐            ┌─────▼─────┐
   │  API      │            │   API     │
   │  Server 1 │            │  Server N │
   └────┬──────┘            └─────┬─────┘
        │                         │
        └────────────┬────────────┘
                     │
            ┌────────▼────────┐
            │   Kafka Cluster │
            │  (workflow-tasks)│
            └────────┬─────────┘
                     │
     ┌───────────────┼───────────────┐
     │               │               │
┌────▼─────┐   ┌────▼─────┐   ┌────▼─────┐
│ Worker 1 │   │ Worker 2 │   │ Worker N │
│ (Async   │   │ (Async   │   │ (Async   │
│ Executor)│   │ Executor)│   │ Executor)│
└────┬─────┘   └────┬─────┘   └────┬─────┘
     │              │              │
     └──────────────┼──────────────┘
                    │
          ┌─────────▼──────────┐
          │   Redis Cluster    │
          │ (Distributed State)│
          └────────────────────┘
```

**Performance Characteristics:**
- **Throughput**: 10,000 workflows/sec
- **Routing latency**: < 1ms (cached), < 5ms (uncached)
- **State save/load**: < 10ms (Redis Cluster)
- **Node execution**: 100-500ms (LLM calls)
- **Total latency**: P95 < 2s for 5-node workflow

**Trade-offs:**

| Aspect | Choice | Trade-off |
|--------|--------|----------|
| State Storage | Redis Cluster | Fast but memory-intensive vs. DB (durable but slower) |
| Execution Model | Async + Message Queue | Complex but scalable vs. sync (simple but limited) |
| Routing | Decision Tree | Fast but rigid vs. LLM routing (flexible but slow) |
| Caching | Aggressive LRU | Higher memory vs. lower latency |
| Serialization | Pickle + Compression | Fast but language-specific vs. JSON (universal but slower) |

**Cost Analysis (10K workflows/sec):**
- Redis Cluster (30 nodes): $3,000/month
- Kafka (9 brokers): $2,000/month  
- Workers (50 instances): $10,000/month
- LLM API calls: $50,000/month (dominant cost)
- **Total**: ~$65,000/month

**Optimization Opportunities:**
1. Model caching (save 30% LLM cost)
2. Batched inference (2x throughput)
3. Smart routing (avoid expensive models when possible)
4. State compression (50% storage reduction)
        ''',
    },
]

for i, qa in enumerate(langgraph_interview_questions, 1):
    # Plain concatenation instead of f'\n{'=' * 100}': reusing the enclosing
    # quote character inside an f-string replacement field is a SyntaxError
    # before Python 3.12. Output is identical.
    print('\n' + '=' * 100)
    print(f'Q{i} [{qa["level"]} Level]')
    print('=' * 100)
    print(f'\n{qa["question"]}\n')
    print('ANSWER:')
    print(qa['answer'])
    print()

### Section 4: Advanced Routing Patterns

Production routing needs:
- **Conditional logic**: Route based on state
- **Parallel execution**: Multiple paths simultaneously
- **Dynamic routing**: LLM decides next step
- **Fallback paths**: Handle edge cases
- **Loop detection**: Prevent infinite cycles

In [None]:
from typing import TypedDict, List, Dict, Any, Callable, Set
import time

class AdvancedGraphState(TypedDict):
    '''Enhanced state for complex workflows.

    Typed schema for the state dict passed between nodes. At runtime a
    TypedDict instance is an ordinary dict.
    '''
    messages: List[dict]  # message dicts; the demos in this section use 'role'/'content' keys
    current_node: str  # node the router is routing *from* (read by AdvancedRouter.route)
    visited_nodes: Set[str]  # node names; not read or written by AdvancedRouter
    loop_count: Dict[str, int]  # per-node entry counter used by AdvancedRouter's loop detection
    data: Dict[str, Any]  # free-form workflow payload
    confidence: float  # score read by routing conditions; presumably in [0, 1] — TODO confirm
    errors: List[dict]  # presumably error records — TODO confirm; not populated by the code shown here

class AdvancedRouter:
    '''Sophisticated routing with loop detection and fallbacks.

    Every transition into a node (conditional, fallback or loop-escape)
    increments ``state['loop_count'][node]``. Once a node has been entered
    ``max_loops`` times, routing *from* it is forced out of the cycle. The
    router uses the node's fallback if that fallback is not itself exhausted,
    otherwise ``escape_node``.
    '''
    
    def __init__(self, max_loops: int = 3, escape_node: str = 'human_review'):
        self.routing_rules = {}
        self.fallback_routes = {}
        # Entries into a node after which it is treated as looping.
        self.max_loops = max_loops
        # Last-resort target when a looping node has no usable fallback.
        self.escape_node = escape_node
    
    def add_conditional_route(self, from_node: str, conditions: Dict[str, Callable]):
        '''Add multi-way conditional routing out of *from_node*.

        *conditions* maps next-node name -> predicate(state). Predicates are
        evaluated in insertion order; the first truthy one selects its node.

        Example:
            conditions = {
                'generate': lambda s: s['confidence'] > 0.8,
                'refine': lambda s: 0.5 < s['confidence'] <= 0.8,
                'human_review': lambda s: s['confidence'] <= 0.5,
            }
        '''
        self.routing_rules[from_node] = conditions
    
    def add_fallback(self, from_node: str, fallback_node: str):
        '''Route *from_node* to *fallback_node* when no condition matches,
        and use it as the first escape target when *from_node* is looping.'''
        self.fallback_routes[from_node] = fallback_node
    
    def route(self, state: 'AdvancedGraphState') -> str:
        '''Return the next node name for ``state['current_node']``.

        Order: loop escape, then conditional routes, then fallback, then 'end'.
        Mutates ``state['loop_count']`` (created if missing).
        '''
        current = state['current_node']
        loop_count = state.setdefault('loop_count', {})
        fallback = self.fallback_routes.get(current)
        
        # Check for infinite loops
        if loop_count.get(current, 0) >= self.max_loops:
            print(f'Loop detected at {current}, escaping to fallback')
            # Don't escape into a node that is itself exhausted, otherwise two
            # nodes that fall back to each other would ping-pong forever.
            if fallback is None or loop_count.get(fallback, 0) >= self.max_loops:
                return self.escape_node
            return self._enter(loop_count, fallback)
        
        # Try conditional routes
        for next_node, condition in self.routing_rules.get(current, {}).items():
            if condition(state):
                return self._enter(loop_count, next_node)
        
        # Use fallback if all conditions fail (counted, so fallback-only
        # cycles are detected too)
        if fallback is not None:
            return self._enter(loop_count, fallback)
        
        # Default: end workflow
        return 'end'
    
    @staticmethod
    def _enter(loop_count: Dict[str, int], node: str) -> str:
        '''Record one entry into *node* and return it.'''
        loop_count[node] = loop_count.get(node, 0) + 1
        return node

class ParallelExecutionGraph:
    '''Execute multiple branches in parallel.

    ``nodes`` maps node name -> synchronous callable ``f(state) -> dict``.
    ``parallel_branches`` records fork/join topology registered through
    ``add_parallel_branches``; ``execute_parallel`` takes the branch list
    explicitly.
    '''
    
    def __init__(self):
        self.nodes = {}
        self.parallel_branches = {}
    
    def add_parallel_branches(self, fork_node: str, branches: List[str], join_node: str):
        '''Define parallel execution branches.
        
        Example:
            fork_node: 'start'
            branches: ['analyze_style', 'analyze_security', 'analyze_performance']
            join_node: 'aggregate_results'
        '''
        self.parallel_branches[fork_node] = {
            'branches': branches,
            'join': join_node
        }
    
    async def execute_parallel(self, state: dict, branches: List[str]) -> dict:
        '''Run every known branch concurrently, each on its own copy of *state*.

        Returns a dict mapping branch name -> the node's result. A branch that
        raises is reported as ``{'error': str(exc)}`` without affecting the
        others. Branch names missing from ``self.nodes`` are skipped.
        '''
        import asyncio
        
        # Start all branches before awaiting any of them so they overlap.
        tasks = [
            (branch, asyncio.create_task(self._execute_node_async(branch, self._isolate(state))))
            for branch in branches
            if branch in self.nodes
        ]
        
        # Wait for all branches
        results = {}
        for branch, task in tasks:
            try:
                results[branch] = await task
            except Exception as e:
                results[branch] = {'error': str(e)}
        
        return results
    
    @staticmethod
    def _isolate(state: dict) -> dict:
        '''Give a branch a private deep copy of *state*.

        With a shallow copy, concurrent branches would share nested lists and
        dicts and could corrupt each other. Falls back to a shallow copy for
        states holding objects that cannot be deep-copied.
        '''
        import copy
        try:
            return copy.deepcopy(state)
        except (TypeError, copy.Error):
            return state.copy()
    
    async def _execute_node_async(self, node_name: str, state: dict) -> dict:
        '''Execute a single node asynchronously.

        Sleeps 0.1 s to simulate async work. It then runs the node's
        synchronous callable in a worker thread, so a blocking node does not
        stall the event loop and the other branches. Unknown nodes yield {}.
        '''
        import asyncio
        
        await asyncio.sleep(0.1)  # Simulate async work
        node = self.nodes.get(node_name)
        if node is None:
            return {}
        return await asyncio.to_thread(node, state)

# Example: parallel code analysis -- print the section banner
print('ADVANCED ROUTING DEMONSTRATION', '=' * 90, sep='\n')

def analyze_style(state):
    '''Style branch: attach the canned style findings to *state* and return it.'''
    findings = ['inconsistent_naming', 'missing_docstrings']
    state['style_issues'] = findings
    return state

def analyze_security(state):
    '''Security branch: record the canned SQL-injection finding on *state*.'''
    key, findings = 'security_issues', ['sql_injection_risk']
    state[key] = findings
    return state

def analyze_performance(state):
    '''Performance branch: tag *state* with the canned complexity finding.'''
    state.update({'performance_issues': ['n_squared_complexity']})
    return state

def aggregate_analysis(state):
    '''Join node: count all branch findings and derive a severity.

    Sets ``state['total_issues']`` to the combined number of style, security
    and performance issues. Sets ``state['severity']`` to 'high' when the
    security branch reported anything, else 'medium'. Missing branches count
    as no issues.
    '''
    style = state.get('style_issues', [])
    security = state.get('security_issues', [])
    performance = state.get('performance_issues', [])
    state['total_issues'] = len(style) + len(security) + len(performance)
    # Decide on the security findings themselves. The old substring check for
    # 'security' in str(all_issues) never matched the actual finding names
    # (e.g. 'sql_injection_risk'), so severity was always 'medium'.
    state['severity'] = 'high' if security else 'medium'
    return state

# Mock async execution
import asyncio

async def demo_parallel():
    '''Fan the three analysis nodes out in parallel, report each branch, aggregate.

    A branch that raised is reported as FAILED instead of "0 issues found".
    Only each branch's own findings key is merged into the aggregate state,
    so an error dict cannot leak into it.
    '''
    graph = ParallelExecutionGraph()
    graph.nodes = {
        'analyze_style': analyze_style,
        'analyze_security': analyze_security,
        'analyze_performance': analyze_performance,
        'aggregate': aggregate_analysis,
    }
    # State key each branch writes its findings into.
    issue_keys = {
        'analyze_style': 'style_issues',
        'analyze_security': 'security_issues',
        'analyze_performance': 'performance_issues',
    }
    
    initial_state = {'code': 'sample code here'}
    
    print('\nExecuting parallel branches...')
    start = time.time()
    
    # Execute in parallel
    branch_results = await graph.execute_parallel(initial_state, list(issue_keys))
    
    elapsed = (time.time() - start) * 1000
    
    print(f'All branches completed in {elapsed:.0f}ms')
    print('\nResults:')
    final_state = initial_state.copy()
    for branch, result in branch_results.items():
        if 'error' in result:
            print(f'  {branch}: FAILED ({result["error"]})')
            continue
        issues = result.get(issue_keys[branch], [])
        print(f'  {branch}: {len(issues)} issues found')
        final_state[issue_keys[branch]] = issues
    
    # Aggregate
    final_state = aggregate_analysis(final_state)
    print(f'\nTotal issues: {final_state["total_issues"]}')
    print(f'Severity: {final_state["severity"]}')

# Run demo. asyncio.run() raises RuntimeError when an event loop is already
# running in this thread (e.g. inside a Jupyter/IPython kernel); in that case
# run the coroutine on a worker thread that owns its own loop and block until
# it finishes.
try:
    asyncio.get_running_loop()
except RuntimeError:
    asyncio.run(demo_parallel())
else:
    from concurrent.futures import ThreadPoolExecutor
    with ThreadPoolExecutor(max_workers=1) as _pool:
        _pool.submit(lambda: asyncio.run(demo_parallel())).result()

print('\n'.join([
    '\n' + '=' * 90,
    'KEY BENEFITS OF PARALLEL EXECUTION:',
    '  - Reduced latency: 3 × 100ms sequential = 300ms',
    '                    vs max(100ms, 100ms, 100ms) = 100ms parallel',
    '  - Better resource utilization',
    "  - Independent failures don't block other branches",
    '  - Can aggregate multiple perspectives',
]))

### Section 5: Observability and Debugging

Production graphs need deep observability:
- **State inspection**: View state at any point
- **Execution traces**: Full history of decisions
- **Performance profiling**: Per-node latency
- **Error tracking**: Capture and analyze failures
- **Replay capability**: Reproduce issues

In [None]:
import json
from datetime import datetime
import uuid

class ObservableGraph:
    '''LangGraph with comprehensive observability.

    Nodes registered through ``add_node`` are wrapped so that every call:
      * appends a 'start' event (with a snapshot of the incoming state),
        followed by an 'end' or 'error' event, to ``self.traces``;
      * updates that node's counters in ``self.node_metrics``.
    Exceptions raised by a node are recorded and re-raised unchanged.
    '''
    
    def __init__(self, workflow_name: str):
        # Local import so the class does not rely on an earlier cell having
        # imported defaultdict.
        from collections import defaultdict
        
        self.workflow_name = workflow_name
        self.nodes = {}
        self.traces = []
        # node name -> counters; an entry is created on the node's first call.
        self.node_metrics = defaultdict(self._empty_metrics)
    
    @staticmethod
    def _empty_metrics() -> dict:
        '''Return a fresh, zeroed metrics record for one node.'''
        return {
            'calls': 0,
            'successes': 0,
            'failures': 0,
            'total_latency_ms': 0,
            'latency_samples': [],
        }
    
    @staticmethod
    def _now_iso() -> str:
        '''Current UTC time as an ISO-8601 string with an explicit offset.

        Replaces datetime.utcnow(), which returns a naive datetime and is
        deprecated since Python 3.12.
        '''
        from datetime import timezone
        return datetime.now(timezone.utc).isoformat()
    
    @staticmethod
    def _snapshot(state: dict) -> dict:
        '''Deep-copy *state* so later in-place mutations cannot rewrite recorded
        history (e.g. appending to 'messages'). Falls back to a shallow copy
        when the state holds objects that cannot be deep-copied.'''
        import copy
        try:
            return copy.deepcopy(state)
        except (TypeError, copy.Error):
            return dict(state)
    
    @staticmethod
    def _percentile(samples: List[float], q: float) -> float:
        '''q-th percentile (0-100) with linear interpolation between closest
        ranks (numpy.percentile's default definition); 0 for no samples.'''
        ordered = sorted(samples)
        if not ordered:
            return 0
        rank = (len(ordered) - 1) * q / 100
        lower = int(rank)
        upper = min(lower + 1, len(ordered) - 1)
        return ordered[lower] + (ordered[upper] - ordered[lower]) * (rank - lower)
    
    @staticmethod
    def _json_default(obj):
        '''Serialize values json cannot handle natively. Sets become lists
        (sorted when orderable); anything else becomes its repr.'''
        if isinstance(obj, (set, frozenset)):
            try:
                return sorted(obj)
            except TypeError:
                return list(obj)
        return repr(obj)
    
    def add_node(self, name: str, func: Callable):
        '''Add node with automatic instrumentation.'''
        self.nodes[name] = self._instrument_function(name, func)
    
    def _record(self, trace_id: str, span_id: str, node_name: str, event: str, **fields) -> None:
        '''Append one trace event for *node_name* to ``self.traces``.'''
        entry = {
            'trace_id': trace_id,
            'span_id': span_id,
            'workflow': self.workflow_name,
            'node': node_name,
            'event': event,
        }
        entry.update(fields)
        entry['timestamp'] = self._now_iso()
        self.traces.append(entry)
    
    def _instrument_function(self, node_name: str, func: Callable) -> Callable:
        '''Wrap *func* with tracing and metrics.

        The wrapper reads ``state['trace_id']``. When it is missing, the
        wrapper generates one and writes it back into *state*, so later nodes
        run on the same state dict share a single trace ID.
        '''
        import functools
        
        @functools.wraps(func)
        def instrumented(state: dict) -> dict:
            trace_id = state.get('trace_id')
            if trace_id is None:
                trace_id = str(uuid.uuid4())
                state['trace_id'] = trace_id
            span_id = str(uuid.uuid4())[:8]
            
            self._record(trace_id, span_id, node_name, 'start',
                         state_snapshot=self._snapshot(state))
            # perf_counter is monotonic (immune to wall-clock adjustments);
            # timing starts after the snapshot so it measures the node alone.
            span_start = time.perf_counter()
            try:
                result = func(state)
            except Exception as e:
                latency_ms = (time.perf_counter() - span_start) * 1000
                self._record(trace_id, span_id, node_name, 'error',
                             error=str(e), latency_ms=latency_ms)
                metrics = self.node_metrics[node_name]
                metrics['calls'] += 1
                metrics['failures'] += 1
                raise
            
            latency_ms = (time.perf_counter() - span_start) * 1000
            self._record(trace_id, span_id, node_name, 'end',
                         status='success', latency_ms=latency_ms)
            metrics = self.node_metrics[node_name]
            metrics['calls'] += 1
            metrics['successes'] += 1
            metrics['total_latency_ms'] += latency_ms
            metrics['latency_samples'].append(latency_ms)
            return result
        
        return instrumented
    
    def get_trace(self, trace_id: str) -> List[dict]:
        '''Get complete trace by ID, in recording order.'''
        return [t for t in self.traces if t['trace_id'] == trace_id]
    
    def get_node_metrics(self, node_name: str = None) -> dict:
        '''Get performance metrics for one node, or for all called nodes when
        *node_name* is None.

        Latency statistics cover successful calls only, because failures
        record no latency sample. Querying an unknown node returns zeros
        without registering it.
        '''
        if node_name is None:
            return {name: self.get_node_metrics(name) for name in list(self.node_metrics)}
        
        metrics = self.node_metrics.get(node_name) or self._empty_metrics()
        calls = metrics['calls']
        samples = metrics['latency_samples']
        return {
            'node': node_name,
            'calls': calls,
            'success_rate': metrics['successes'] / calls if calls > 0 else 0,
            'avg_latency_ms': metrics['total_latency_ms'] / len(samples) if samples else 0,
            'p50_latency_ms': self._percentile(samples, 50),
            'p95_latency_ms': self._percentile(samples, 95),
        }
    
    def visualize_trace(self, trace_id: str) -> str:
        '''Generate an ASCII visualization of an execution trace.'''
        trace_events = self.get_trace(trace_id)
        
        if not trace_events:
            return 'No trace found'
        
        viz = f'\nExecution Trace: {trace_id}\n'
        viz += '=' * 80 + '\n'
        
        for event in trace_events:
            indent = '  ' * (1 if event['event'] == 'start' else 2)
            
            if event['event'] == 'start':
                viz += f"{indent}→ {event['node']} (starting...)\n"
            elif event['event'] == 'end':
                viz += f"{indent}✓ {event['node']} ({event['latency_ms']:.0f}ms)\n"
            elif event['event'] == 'error':
                viz += f"{indent}✗ {event['node']} - Error: {event['error'][:50]}...\n"
        
        viz += '=' * 80
        return viz
    
    def export_trace_json(self, trace_id: str, filename: str):
        '''Export a trace as JSON for offline analysis.

        State snapshots may contain sets (e.g. 'visited_nodes'), which plain
        json.dump rejects; ``_json_default`` converts them.
        '''
        trace = self.get_trace(trace_id)
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(trace, f, indent=2, default=self._json_default)

# Demo: build an instrumented graph for a customer-support flow
print('OBSERVABLE GRAPH WITH TRACING', '=' * 90, sep='\n')

graph = ObservableGraph(workflow_name='customer_support_flow')

# Node functions for the demo workflow
def classify_query(state):
    '''Classify the latest message as a 'refund' or 'general' query.

    Sets ``state['query_type']``: 'refund' if the latest message's content
    contains "refund" (case-insensitive), else 'general'. Also sets a fixed
    placeholder ``state['confidence'] = 0.85``. A missing or empty message
    list, or a message without content, is classified as 'general' rather
    than raising.
    '''
    messages = state.get('messages') or []
    latest = (messages[-1].get('content') or '') if messages else ''
    state['query_type'] = 'refund' if 'refund' in latest.lower() else 'general'
    state['confidence'] = 0.85
    return state

def process_refund(state):
    '''Refund handler stub: simulate ~50 ms of work, then mark the refund processed.'''
    simulated_work_s = 0.05
    time.sleep(simulated_work_s)
    state.update(refund_processed=True)
    return state

def process_general(state):
    '''General-query handler stub: simulate ~30 ms of work, then flag a generated response.'''
    pause_s = 0.03
    time.sleep(pause_s)
    state.update({'response_generated': True})
    return state

graph.add_node(name='classify', func=classify_query)
graph.add_node(name='process_refund', func=process_refund)
graph.add_node(name='process_general', func=process_general)

# Execute multiple workflows. Each run is routed on the classifier's decision
# instead of a fixed classify -> process_general path, so both handlers are
# exercised; run 1 is a refund request.
for i in range(3):
    trace_id = str(uuid.uuid4())
    content = f'Please refund my order {i}' if i == 1 else f'Test query {i}'
    state = {
        'trace_id': trace_id,
        'messages': [{'role': 'user', 'content': content}],
        'current_node': 'classify',
        'visited_nodes': set(),
        'loop_count': {},
        'data': {},
        'confidence': 0.0,
        'errors': []
    }
    
    # Execute nodes: classify first, then the handler it selected.
    state = graph.nodes['classify'](state)
    node = 'process_refund' if state['query_type'] == 'refund' else 'process_general'
    state['current_node'] = node
    state = graph.nodes[node](state)

print('\nNode Performance Metrics:', '-' * 80, sep='\n')
for node_name, metrics in graph.get_node_metrics().items():
    print(f'\n{node_name}:')
    for key, value in metrics.items():
        # Floats are shown with two decimals; everything else as-is.
        print(f'  {key}: {value:.2f}' if isinstance(value, float) else f'  {key}: {value}')

# Render the first recorded trace, if any node has run
if graph.traces:
    sample_trace_id = next(iter(graph.traces))['trace_id']
    print(f'\n{graph.visualize_trace(sample_trace_id)}')

print('\n'.join([
    '\n' + '=' * 90,
    'OBSERVABILITY BEST PRACTICES:',
    '  - Trace every execution with unique trace_id',
    '  - Capture state snapshots at each node',
    '  - Measure latency per node (identify bottlenecks)',
    '  - Track success/failure rates per node',
    '  - Export traces for offline analysis',
    '  - Use distributed tracing in production (Jaeger, Zipkin)',
]))