In [None]:
"""
Policy Compliance Agent - Clean implementation with RAG support
Handles policy compliance checking and policy-related questions
"""
import autogen
from typing import Dict, Any, List, Union, Optional
import json


class PolicyComplianceAgent:
    """Agent for policy compliance checking and explanation.

    Wraps an ``autogen.AssistantAgent`` configured with a policy-focused system
    message. When a vector manager is supplied, prompts are enriched with
    policy snippets retrieved from the ``policies`` collection (RAG).
    """

    # Vector-store collection that holds the policy documents.
    POLICY_COLLECTION = "policies"
    # Minimum similarity score passed to the vector store for policy hits.
    SCORE_THRESHOLD = 0.3

    def __init__(self, llm_config: Dict[str, Any], vector_manager=None):
        """
        Initialize Policy Compliance Agent

        Args:
            llm_config: LLM configuration for autogen
            vector_manager: Optional vector-store client used for RAG. It is
                called as ``search(collection_name=..., query=..., top_k=...,
                score_threshold=...)``. When ``None``, prompts are built
                without retrieved policy context.
        """
        self.llm_config = llm_config
        self.vector_manager = vector_manager

        self.agent = autogen.AssistantAgent(
            name="PolicyComplianceAgent",
            system_message="""You are a Policy Compliance Agent specializing in expense policy management.

Your responsibilities:
1. Check transaction compliance against company policies
2. Explain policy requirements clearly and concisely
3. Answer policy-related questions with specific examples
4. Provide actionable guidance on what users need to do

Policy Structure:
- Policies are organized by: Field → Sub-field → Relation → Value → Result
- Example: Expense → Amount → Greater than → $100 → Require receipt

Available Policy Actions:
- Block transaction (non-compliant)
- Require receipt
- Require memo
- Require approval from specific role
- Require attendee list
- Auto-generate memo

When responding:
- Be clear, concise, and professional
- Cite specific policy rules when applicable
- List all requirements in bullet points
- Explain WHY each requirement applies
- Provide examples with actual dollar amounts
- Use English for all communications

Format your responses with:
- Clear headers (##)
- Bullet points for lists
- Bold for important terms
- Examples when helpful
""",
            llm_config=llm_config,
            human_input_mode="NEVER",
            max_consecutive_auto_reply=3,
        )

    def check_transaction_compliance(
        self,
        transaction: Dict[str, Any],
        context: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Check if a transaction complies with policies

        Args:
            transaction: Transaction data (amount, category, merchant, etc.)
            context: Additional context (user, department, budget, etc.)

        Returns:
            Compliance report as formatted string (or an error message string
            if the LLM call fails)
        """
        context = context or {}

        # Retrieve relevant policies from RAG (empty list when unavailable)
        rag_context = self._get_policy_context(transaction, context)

        prompt = self._build_compliance_prompt(transaction, context, rag_context)
        return self._get_agent_reply(prompt)

    def answer_policy_question(self, question: str) -> str:
        """
        Answer general policy-related questions

        Args:
            question: User's policy question

        Returns:
            Answer with policy details and examples
        """
        # The question itself is used as the semantic search query.
        rag_context = self._search_policies(question, top_k=5)

        prompt = f"""Answer this policy question:

**Question:** {question}

**Relevant Policy Information:**
{self._format_rag_context(rag_context)}

Provide a comprehensive answer that includes:
1. Direct answer to the question
2. Specific policy rules that apply
3. Examples with actual dollar amounts or scenarios
4. Any exceptions or special cases
5. What users should do to comply
"""

        return self._get_agent_reply(prompt)

    def explain_policy_requirements(
        self,
        category: str,
        amount: Optional[float] = None
    ) -> str:
        """
        Explain policy requirements for a specific category/amount

        Args:
            category: Expense category (e.g., "Travel", "Meals")
            amount: Optional amount to check specific thresholds. An explicit
                ``0`` is treated as a real amount; only ``None`` means "no
                amount given".

        Returns:
            Explanation of policy requirements
        """
        # `is not None` (not truthiness) so that an amount of 0 is still used.
        has_amount = amount is not None

        query = f"policy requirements for {category}"
        if has_amount:
            query += f" expense amount ${amount}"

        rag_context = self._search_policies(query, top_k=3)

        amount_line = f"**Amount:** ${amount}" if has_amount else ""
        prompt = f"""Explain the policy requirements for:

**Category:** {category}
{amount_line}

**Relevant Policies:**
{self._format_rag_context(rag_context)}

Provide a clear explanation including:
1. What documents are required (receipt, memo, etc.)
2. Any approval requirements
3. Amount thresholds that trigger different requirements
4. Specific rules for this category
5. Examples of compliant vs non-compliant expenses
"""

        return self._get_agent_reply(prompt)

    def _get_policy_context(
        self,
        transaction: Dict[str, Any],
        context: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """Retrieve policy context relevant to a transaction.

        Builds a keyword-style query from the transaction's amount, category
        and merchant plus the context's department, then delegates to
        ``_search_policies``. Returns ``[]`` when nothing usable is present.
        """
        query_parts = []

        amount = transaction.get('amount')
        if amount is not None:
            query_parts.append(f"amount ${amount}")

        if transaction.get('category'):
            query_parts.append(transaction['category'])

        if transaction.get('merchant'):
            query_parts.append(transaction['merchant'])

        if context.get('department'):
            query_parts.append(f"department {context['department']}")

        return self._search_policies(" ".join(query_parts), top_k=5)

    def _search_policies(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Search the policy collection for ``query``.

        Best-effort: returns ``[]`` when no vector manager is configured, when
        the query is blank, when the search returns nothing, or when the
        search raises (the error is printed, not propagated).
        """
        if not self.vector_manager or not (query or "").strip():
            return []

        try:
            results = self.vector_manager.search(
                collection_name=self.POLICY_COLLECTION,
                query=query,
                top_k=top_k,
                score_threshold=self.SCORE_THRESHOLD
            )
        except Exception as e:
            # Retrieval is optional enrichment; degrade to "no context".
            print(f"⚠️ Error searching policies: {e}")
            return []
        return results or []

    def _build_compliance_prompt(
        self,
        transaction: Dict[str, Any],
        context: Dict[str, Any],
        rag_context: List[Dict[str, Any]]
    ) -> str:
        """Build the LLM prompt for a compliance check.

        Missing transaction/context fields are rendered as ``N/A`` (amount
        defaults to ``0``).
        """
        prompt = f"""Check compliance for this transaction:

## Transaction Details
- **Amount:** ${transaction.get('amount', 0)}
- **Category:** {transaction.get('category', 'N/A')}
- **Merchant:** {transaction.get('merchant', 'N/A')}
- **Date:** {transaction.get('date', 'N/A')}
- **Description:** {transaction.get('description', 'N/A')}

## Context
- **User:** {context.get('user_name', 'N/A')}
- **Department:** {context.get('department', 'N/A')}
- **Role:** {context.get('role', 'N/A')}
- **Budget ID:** {context.get('budget_id', 'N/A')}

## Relevant Policies
{self._format_rag_context(rag_context)}

## Your Task
Analyze this transaction and provide a compliance report with:

1. **Compliance Status:** Is this transaction compliant? (Yes/No/Conditional)

2. **Requirements:** List ALL requirements that apply:
   - Receipt required? (Yes/No)
   - Memo required? (Yes/No/Auto-generated)
   - Approval required? (Yes/No - from whom?)
   - Attendee list required? (Yes/No)
   - Any other requirements?

3. **Policy Explanation:** Which specific policies apply and why?

4. **Next Steps:** What should the user do to ensure compliance?

5. **Warnings:** Any potential issues or things to watch out for?

Format your response clearly with headers and bullet points.
"""
        return prompt

    def _format_rag_context(self, rag_context: List[Dict[str, Any]]) -> str:
        """Render retrieved policy hits as numbered markdown entries.

        Each hit is read as a dict with optional ``text``, ``score`` and
        ``metadata`` (``section``/``category``) keys; missing or ``None``
        values fall back to defaults instead of raising.
        """
        if not rag_context:
            return "No specific policy information retrieved. Using general policy knowledge."

        formatted = []
        for i, item in enumerate(rag_context, 1):
            # `or` fallbacks guard against keys present with a None value,
            # which would otherwise break `:.2f` / `.get` below.
            text = item.get('text') or ''
            score = item.get('score') or 0
            metadata = item.get('metadata') or {}

            section = metadata.get('section', 'N/A')
            category = metadata.get('category', 'N/A')

            formatted.append(
                f"**Policy {i}** (Relevance: {score:.2f})\n"
                f"Section: {section} | Category: {category}\n"
                f"{text}\n"
            )

        return "\n".join(formatted)

    def _get_agent_reply(self, prompt: str) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Any exception from the agent is converted into an error string rather
        than propagated.
        """
        try:
            result = self.agent.generate_reply(
                messages=[{"role": "user", "content": prompt}]
            )
            return self._extract_content(result)
        except Exception as e:
            return f"Error generating response: {e}"

    def _extract_content(
        self,
        result: Union[str, Dict, List],
        default: str = "Unable to process request"
    ) -> str:
        """Extract text content from an autogen reply.

        Handles a plain string, a message dict (``content`` key), or a list of
        strings/message dicts (joined with newlines). Always returns a
        ``str``: ``default`` is used when no non-empty text content is found
        (including ``None`` replies and dicts whose ``content`` is ``None``).
        """
        if isinstance(result, str):
            return result

        if isinstance(result, dict):
            content = result.get("content")
            # Tool-call style messages may carry content=None; never leak a
            # non-str through the `-> str` contract.
            return content if isinstance(content, str) and content else default

        if isinstance(result, list):
            contents = []
            for item in result:
                if isinstance(item, str):
                    contents.append(item)
                elif isinstance(item, dict):
                    content = item.get("content", "")
                    if content:
                        contents.append(content)

            return "\n".join(contents) if contents else default

        return default