In [0]:
import mlflow
import pandas as pd
from typing import Dict, List, Any
import json
import os
from databricks.vector_search.client import VectorSearchClient
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import ServingClient

class InvoiceProcessingAgent:
    """Classify documents and extract key fields via a Databricks chat endpoint.

    Prompts currently come from a hard-coded dictionary
    (``_load_static_prompts``); loading them from the registered MLflow model
    (``_load_prompts``) is not wired into ``__init__``.
    """

    def __init__(self, model_uri: str, endpoint_name: str = "claude-3-7-sonnet"):
        """
        Initialize the Invoice Processing Agent.

        Args:
            model_uri: URI of the registered MLflow prompt model. It is stored
                for reference only; the model itself is not loaded here.
            endpoint_name: Name of the Databricks serving endpoint to query.
        """
        self.model_uri = model_uri
        # Prompts from the registered model are not loaded yet; use the
        # static prompts until _load_prompts is wired up.
        self.prompts = self._load_static_prompts()
        self.endpoint_name = endpoint_name

        # Workspace client used to query the serving endpoint. It is
        # constructed without arguments, so the SDK's default auth applies.
        self.workspace_client = WorkspaceClient()

    def _load_static_prompts(self) -> Dict[str, str]:
        """Return the built-in prompts, keyed by task name."""
        return {
            "document_classification": "Classify this document into one of the following categories: invoice, receipt, ID, contract.",
            "information_extraction": "Extract the following fields from this document: date, total amount, vendor name, transaction ID."
        }

    def _load_prompts(self) -> Dict[str, str]:
        """Load the prompts from the registered model's JSON artifact.

        Requires ``self.prompt_model`` to have been set; ``__init__`` in this
        class does not set it.

        Returns:
            The parsed contents of the artifact stored under ``"data_path"``.

        Raises:
            RuntimeError: If ``self.prompt_model`` has not been set.
        """
        prompt_model = getattr(self, "prompt_model", None)
        if prompt_model is None:
            raise RuntimeError(
                "prompt_model is not loaded; load the MLflow prompt model "
                "before calling _load_prompts()"
            )
        # The prompts were stored in 'invoices_prompts.json'.
        model_path = prompt_model.context.artifacts["data_path"]
        with open(model_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _call_claude_endpoint(self, messages: List[Dict[str, str]]) -> str:
        """
        Send a chat request to the configured serving endpoint.

        Args:
            messages: List of message dictionaries with 'role'
                ("system", "user" or "assistant") and 'content'.

        Returns:
            Text content of the first choice ("" if it carries no content).

        Raises:
            RuntimeError: If the endpoint returns no choices.
        """
        # Imported locally so this method is self-contained.
        from databricks.sdk.service.serving import ChatMessage, ChatMessageRole

        chat_messages = [
            ChatMessage(role=ChatMessageRole(m["role"]), content=m["content"])
            for m in messages
        ]
        response = self.workspace_client.serving_endpoints.query(
            name=self.endpoint_name,
            messages=chat_messages,
            temperature=0.0,  # minimize sampling randomness
            max_tokens=4096,
        )
        if not response.choices:
            raise RuntimeError(
                f"Endpoint {self.endpoint_name!r} returned no choices"
            )
        message = response.choices[0].message
        return (message.content if message is not None else None) or ""

    def classify_document(self, document_text: str) -> str:
        """
        Classify the document using the classification prompt.

        Args:
            document_text: Text content of the document to classify

        Returns:
            The model's classification answer, stripped of surrounding whitespace
        """
        classification_prompt = self.prompts["document_classification"]
        full_prompt = f"{classification_prompt}\n\nDocument: {document_text}"

        messages = [
            {"role": "system", "content": "You are an expert document classifier."},
            {"role": "user", "content": full_prompt}
        ]

        return self._call_claude_endpoint(messages).strip()

    @staticmethod
    def _parse_json_response(response: str) -> Dict[str, Any]:
        """Parse a model reply as a JSON object, tolerating markdown fences.

        Args:
            response: Raw text returned by the model.

        Returns:
            The parsed object if the reply (optionally wrapped in a
            ```/```json fence) is a JSON object; otherwise
            ``{"raw_extraction": <stripped reply>}``.
        """
        text = response.strip()
        if text.startswith("```"):
            # Drop the opening fence line (``` or ```json) and closing fence.
            _, _, text = text.partition("\n")
            text = text.rstrip()
            if text.endswith("```"):
                text = text[:-3]
            text = text.strip()
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            return {"raw_extraction": response.strip()}
        # Callers expect a mapping; wrap lists/scalars like unparseable text.
        if isinstance(parsed, dict):
            return parsed
        return {"raw_extraction": response.strip()}

    def extract_information(self, document_text: str) -> Dict[str, Any]:
        """
        Extract information from the document using the extraction prompt.

        Args:
            document_text: Text content of the document

        Returns:
            The extracted fields as a dict, or ``{"raw_extraction": text}``
            when the model's reply is not a JSON object
        """
        extraction_prompt = self.prompts["information_extraction"]
        full_prompt = f"{extraction_prompt}\n\nDocument: {document_text}"

        messages = [
            {"role": "system", "content": "You are an expert at extracting structured information from documents. Return the information in JSON format."},
            {"role": "user", "content": full_prompt}
        ]

        return self._parse_json_response(self._call_claude_endpoint(messages))

    def process_document(self, document_text: str) -> Dict[str, Any]:
        """
        Process a document by classifying it and extracting relevant information.

        Args:
            document_text: Text content of the document

        Returns:
            Dictionary with "document_class" and "extracted_information" keys
        """
        document_class = self.classify_document(document_text)
        extracted_info = self.extract_information(document_text)

        return {
            "document_class": document_class,
            "extracted_information": extracted_info
        }

    def process_batch(self, documents: List[str]) -> List[Dict[str, Any]]:
        """
        Process documents one after another, in input order.

        Args:
            documents: List of document texts

        Returns:
            One ``process_document`` result per input document
        """
        return [self.process_document(doc) for doc in documents]


# Example usage in a Databricks notebook
def main():
    """Run the agent on a sample invoice, then on a two-document batch.

    Prints the single document's class and extracted information, followed
    by the number of documents processed in the batch.
    """
    # Agent backed by the registered prompt model and the Claude endpoint.
    agent = InvoiceProcessingAgent(
        "models:/invoices_prompt_model/latest",
        endpoint_name="claude-3-7-sonnet",
    )

    # Sample invoice reused by both the single and batch runs.
    sample_invoice = """
    INVOICE
    
    Invoice #: INV-2023-1234
    Date: 2023-11-15
    
    Vendor: ABC Office Supplies
    Customer: XYZ Corporation
    
    Items:
    - Office paper (10 reams): $45.00
    - Printer ink cartridges (2): $65.00
    - Desk organizers (5): $25.00
    
    Subtotal: $135.00
    Tax (10%): $13.50
    Total: $148.50
    
    Payment due within 30 days.
    """

    single = agent.process_document(sample_invoice)
    print("Document Class:", single["document_class"])
    print("Extracted Information:", single["extracted_information"])

    batch = [sample_invoice, "Another document text..."]
    processed = agent.process_batch(batch)
    print(f"Processed {len(processed)} documents")

In [0]:
# Run the end-to-end demo: queries the serving endpoint and prints the results.
main()