In [45]:
import dotenv
import time
import os
from typing import TypedDict, List, Dict, Any, Optional, Union
from IPython.display import Image, display
import weaviate
from weaviate.auth import Auth
import weaviate.classes as wvc
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_weaviate.vectorstores import WeaviateVectorStore
from langchain_huggingface import HuggingFaceEmbeddings

# Importing necessary libraries
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.output_parsers import JsonOutputParser
from tavily import TavilyClient
from langgraph.graph import StateGraph, END

import networkx as nx
import matplotlib.pyplot as plt

dotenv.load_dotenv()

True

In [46]:
class DocumentProcessor:
    """Process legal documents and create vector store"""
    
    def __init__(self, documents_dir: str = "./notes"):
        self.documents_dir = documents_dir
        self.weaviate_url = os.environ.get("WEAVIATE_URL")
        self.weaviate_api_key = os.environ.get("WEAVIATE_API_KEY")
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200
        )
        
    def load_documents(self) -> List[Any]:
        """Load documents from the directory"""
        try:
            loader = DirectoryLoader(
                self.documents_dir,
                glob="**/*.pdf",
                loader_cls=PyPDFLoader
            )
            documents = loader.load()
            print(f"Loaded {len(documents)} documents.")
            return documents
        except Exception as e:
            print(f"Error loading documents: {e}")
            return []
    
    def process_documents(self) -> List[Any]:
        """Split documents into chunks"""
        documents = self.load_documents()
        chunks = self.text_splitter.split_documents(documents)
        print(f"Split into {len(chunks)} chunks")
        return chunks
    
    def create_vector_store(self) -> WeaviateVectorStore:
        """Create and populate vector store with documents using LangChain's WeaviateVectorStore"""
        
        client = weaviate.connect_to_weaviate_cloud(
            cluster_url=self.weaviate_url,
            auth_credentials=Auth.api_key(self.weaviate_api_key),
        )
        
        chunks = self.process_documents()
        
        if client.collections.exists("LegalDocuments"):
            client.collections.delete("LegalDocuments")
        
        vector_store = WeaviateVectorStore.from_documents(
            documents=chunks,
            embedding=self.embeddings,
            client=client,
            index_name="LegalDocuments", 
            text_key="content",
            by_text=False
        )
        
        print(f"Successfully imported {len(chunks)} chunks into Weaviate")
        return vector_store
    
    def query_store(self, query: str, vector_store: WeaviateVectorStore, k: int = 5):
        """Query the vector store for similar documents"""
        docs = vector_store.similarity_search(query, k=k)
        return docs

In [47]:
# processor = DocumentProcessor(documents_dir="../streamlit_app/notes/")
# vector_store = processor.create_vector_store()

# results = processor.query_store("contract breach provisions", vector_store)

# for i, doc in enumerate(results):
#     print(f"\nDocument {i+1}:")
#     print(f"Source: {doc.metadata.get('source', 'Unknown')}, Page: {doc.metadata.get('page', 'Unknown')}")
#     print(doc.page_content[:150] + "...")

In [48]:
class EnhancedAgentState(TypedDict):
    """Enhanced state management for the Legal AI Assistant"""
    input: Any
    input_type: str
    processed_input: Optional[Dict[str, Any]]
    query_details: Optional[Dict[str, Any]]
    document_search_results: Optional[List[Dict[str, Any]]]
    document_search_sufficient: Optional[bool]
    web_search_results: Optional[List[Dict[str, Any]]]
    web_search_sufficient: Optional[bool]
    need_additional_search: Optional[bool]
    final_response: Optional[str]
    references: Optional[List[str]]
    conversation_history: List[Union[HumanMessage, AIMessage]]

def determine_search_sufficiency(state: EnhancedAgentState, search_type: str, threshold: float = 7.0) -> Dict[str, Any]:
    """Determine if search results are sufficient based on relevance score"""
    if search_type == "document":
        evaluation = state.get("document_search_evaluation", {})
        relevance_score = evaluation.get("Relevance Score", 0)
        sufficient = relevance_score >= threshold
        
        return {
            "document_search_sufficient": sufficient,
            "need_additional_search": not sufficient
        }
    elif search_type == "web":
        evaluation = state.get("web_search_evaluation", {})
        relevance_score = evaluation.get("Relevance Score", 0)
        sufficient = relevance_score >= threshold
        
        return {
            "web_search_sufficient": sufficient,
            "need_additional_search": state.get("need_additional_search", True) and not sufficient
        }
    else:
        raise ValueError(f"Unsupported search type: {search_type}")

In [49]:
import os
from typing import Dict, Any, Optional, List, Union
import tempfile
from PIL import Image
import pytesseract
from langchain_community.document_loaders import PyPDFLoader
from io import BytesIO

class MultimodalInputHandler:
    """Handle different types of inputs (text, image, file)"""
    
    def __init__(self):
        # Configure pytesseract path if needed
        # pytesseract.pytesseract.tesseract_cmd = r'<path_to_tesseract_executable>'
        pass
    
    def process_text(self, text: str) -> Dict[str, Any]:
        """Process plain text input"""
        return {
            "type": "text",
            "content": text,
            "metadata": {}
        }
    
    def process_image(self, image_data: Union[str, bytes, Image.Image]) -> Dict[str, Any]:
        """Process image input using OCR"""
        if isinstance(image_data, str):
            # Assuming image_data is a file path
            image = Image.open(image_data)
        elif isinstance(image_data, bytes):
            image = Image.open(BytesIO(image_data))
        else:
            image = image_data
            
        # Extract text from image using OCR
        extracted_text = pytesseract.image_to_string(image)
        
        return {
            "type": "image",
            "content": extracted_text,
            "metadata": {
                "original_format": "image",
                "image_size": image.size,
                "image_mode": image.mode
            }
        }
    
    def process_pdf(self, pdf_data: Union[str, bytes]) -> Dict[str, Any]:
        """Process PDF input"""
        if isinstance(pdf_data, bytes):
            # Create a temporary file
            with tempfile.NamedTemporaryFile(suffix='.pdf', delete=False) as temp_file:
                temp_file.write(pdf_data)
                temp_file_path = temp_file.name
                
            try:
                loader = PyPDFLoader(temp_file_path)
                documents = loader.load()
                
                # Extract text from all pages
                full_text = "\n".join(doc.page_content for doc in documents)
                
                return {
                    "type": "pdf",
                    "content": full_text,
                    "metadata": {
                        "original_format": "pdf",
                        "page_count": len(documents),
                        "documents": documents
                    }
                }
            finally:
                # Remove temporary file
                os.unlink(temp_file_path)
        else:
            # Assuming pdf_data is a file path
            loader = PyPDFLoader(pdf_data)
            documents = loader.load()
            
            # Extract text from all pages
            full_text = "\n".join(doc.page_content for doc in documents)
            
            return {
                "type": "pdf",
                "content": full_text,
                "metadata": {
                    "original_format": "pdf",
                    "page_count": len(documents),
                    "documents": documents
                }
            }
    
    def process_input(self, input_data: Any, input_type: str) -> Dict[str, Any]:
        """Process any input based on its type"""
        if input_type == "text":
            return self.process_text(input_data)
        elif input_type == "image":
            return self.process_image(input_data)
        elif input_type == "pdf":
            return self.process_pdf(input_data)
        else:
            raise ValueError(f"Unsupported input type: {input_type}")

In [58]:
## setting up AI Agent workflow
class LegalAIAssistant:
    def __init__(self, weaviate_url: Optional[str] = None):
        self.llm = ChatGroq(
            model="llama3-70b-8192",
            temperature=0.6,
            api_key=os.getenv("GROQ_API_KEY")
        )
        
        self.tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
    
        self.document_processor = DocumentProcessor(documents_dir="../streamlit_app/notes/")
        self.vector_store = self.document_processor.create_vector_store()

        self.input_handler = MultimodalInputHandler()
        
        self.query_understanding_system = """You are an expert legal AI assistant specializing in understanding complex legal queries.
        Your task is to analyze the user's input and break it down into components that will guide a comprehensive legal search and response.
        Pay special attention to:
        1. Identifying the core legal issue or question
        2. Determining relevant jurisdictions
        3. Identifying specific legal domains (criminal, civil, corporate, etc.)
        4. Extracting potential subqueries that need separate investigation
        5. Identifying any time-sensitive elements
        
        Format your analysis as a structured JSON object.
        """
        
        self.document_evaluation_system = """You are an expert legal document analyst.
        Your task is to evaluate search results from a legal document database and determine if they adequately address the user's query.
        Consider:
        1. Relevance of the documents to the specific legal question
        2. Comprehensiveness of the information provided
        3. Accuracy and authority of the sources
        4. Whether the information is complete or requires additional context
        
        Assign a relevance score (0-10) and explain your reasoning.
        """
        
        self.web_evaluation_system = """You are an expert legal research analyst.
        Your task is to evaluate web search results and determine if they adequately address aspects of the user's legal query.
        Consider:
        1. Credibility of the sources (government sites, law firms, legal journals)
        2. Relevance to the specific legal question
        3. Currency of the information (especially important for evolving legal topics)
        4. Whether the results complement the document search results
        
        Assign a relevance score (0-10) and explain your reasoning.
        """
        
        self.final_response_system = """You are a comprehensive legal AI assistant tasked with providing accurate, nuanced, and helpful legal information.
        When generating your response:
        1. Focus on factual legal information and procedural guidance
        2. Clearly distinguish between established law, legal interpretation, and practical advice
        3. Include relevant citations and references to legal statutes, cases, or authorities
        4. Provide balanced perspectives where legal interpretations differ
        5. Clarify any jurisdictional limitations to your advice
        6. Include appropriate disclaimers about not providing legal advice
        
        Structure your response in a clear, logical format with headings where appropriate.
        """
        
        # Initialize prompts
        self._initialize_prompts()

        self.visualize_workflow("./img/workflow.png")
    
    def _initialize_prompts(self):
        """Initialize all prompts used by the assistant"""
        # Query Understanding Prompt
        self.query_understanding_prompt = ChatPromptTemplate.from_messages([
            ("system", self.query_understanding_system),
            ("human", """Analyze the following legal query and break it down into its key components:

            {processed_input}

            Return a structured JSON with these fields:
            - core_legal_issue: The main legal question or problem
            - jurisdiction: Relevant legal jurisdiction(s) if specified or can be inferred
            - legal_domains: List of relevant legal areas (e.g., criminal, civil, property)
            - subqueries: List of related questions that might need separate investigation
            - time_sensitivity: Any urgent aspects of the query
            - key_terms: Important legal terms mentioned or implied in the query
             
            1. **First, output JSON only** without any inline comments.
            2. **After the JSON, provide explanations** in natural language.
            """)
        ])
        
        # Document Search Evaluation Prompt
        self.document_evaluation_prompt = ChatPromptTemplate.from_messages([
            ("system", self.document_evaluation_system),
            ("human", """Evaluate these document search results for the legal query:

            Query Details: {query_details}

            Document Search Results:
            {document_search_results}

            Provide a JSON response with these fields:
            - Relevance Score: (0-10)
            - Key Matching Sections: List of sections most relevant to the query
            - Information Gaps: Legal aspects of the query not covered by these documents
            - Confidence Assessment: Your confidence in the documents answering the query correctly
            """)
        ])
        
        # Web Search Evaluation Prompt
        self.web_evaluation_prompt = ChatPromptTemplate.from_messages([
            ("system", self.web_evaluation_system),
            ("human", """Evaluate these web search results for the legal query:

            Query Details: {query_details}

            Web Search Results:
            {web_search_results}

            Provide a JSON response with these fields:
            - Relevance Score: (0-10)
            - Key Insights: Main legal information found in the results
            - Source Credibility: Assessment of the credibility of the sources
            - Information Gaps: Aspects of the query not adequately addressed
            - Comparison to Document Results: How these results complement the document search
            """)
        ])
        
        # Final Response Generation Prompt
        self.final_response_prompt = ChatPromptTemplate.from_messages([
            ("system", self.final_response_system),
            ("human", """Generate a comprehensive legal response based on the following:

            Original Query: {processed_input}
            Query Analysis: {query_details}
            Document Search Results: {document_search_results}
            Web Search Results: {web_search_results}

            Your response should include:
            1. A clear explanation of the legal concepts and principles
            2. Applicable laws, regulations, or precedents
            3. Practical guidance on how to proceed
            4. Any necessary disclaimers about jurisdictional limitations
            5. References to sources used

            Remember to remain balanced, factual, and helpful while acknowledging legal complexities.
            """)
        ])
    
    def process_input_node(self, state: EnhancedAgentState) -> Dict[str, Any]:
        """Process the input based on its type"""
        processed_input = self.input_handler.process_input(
            state['input'], 
            state['input_type']
        )

        if 'conversation_history' not in state:
            state['conversation_history'] = []
        
        return {"processed_input": processed_input}
    
    def understand_query_node(self, state: EnhancedAgentState) -> Dict[str, Any]:
        """Node for understanding the query"""
        chain = self.query_understanding_prompt | self.llm | JsonOutputParser()

        human_message = HumanMessage(
            content=state['processed_input']['content'],
            additional_kwargs={"timestamp": time.time()}
        )
        state['conversation_history'].append(human_message)
        max_history_size = 10 
        if len(state['conversation_history']) > max_history_size:
            state['conversation_history'] = state['conversation_history'][-max_history_size:]

        recent_context = state['conversation_history'][-3:]

        context_enhanced_query = state['processed_input']['content']
        query_details = chain.invoke({"processed_input": context_enhanced_query})
        
        return {
            "query_details": query_details,
            "conversation_history": state['conversation_history']
        }

    def document_search_node(self, state: EnhancedAgentState) -> Dict[str, Any]:
        """Node for searching legal documents"""
        # Extract key terms from query details
        key_terms = state['query_details'].get('key_terms', [])
        core_issue = state['query_details'].get('core_legal_issue', '')
        
        # Combine terms for search
        search_query = f"{core_issue} {' '.join(key_terms)}"
        
        # Use vector store to search documents
        search_results = self.vector_store.similarity_search_with_score(
            query=search_query,
            k=5,
        )
        
        # Format results for the LLM
        document_search_results = [
            {
                "source": result[0].metadata.get('source', 'Unknown'),
                "page": result[0].metadata.get('page', 0),
                "relevance_score": result[1],
                "content": result[0].page_content
            }
            for result in search_results
        ]
        
        # Evaluate search results
        chain = self.document_evaluation_prompt | self.llm | JsonOutputParser()
        document_evaluation = chain.invoke({
            "query_details": state['query_details'],
            "document_search_results": document_search_results
        })
        
        return {
            "document_search_results": document_search_results,
            "document_search_evaluation": document_evaluation
        }

    def evaluate_doc_search_node(self, state: EnhancedAgentState) -> Dict[str, Any]:
        """Node for evaluating document search results and deciding next steps"""
        return determine_search_sufficiency(state, "document")

    def web_search_node(self, state: EnhancedAgentState) -> Dict[str, Any]:
        """Node for web searching"""
        # Use the core legal issue and key terms for search
        core_issue = state['query_details'].get('core_legal_issue', '')
        jurisdiction = state['query_details'].get('jurisdiction', '')
        
        # Construct a more specific query for web search
        web_query = f"{core_issue} legal {jurisdiction}"
        
        web_search_results = self.tavily_client.search(
            query=web_query, 
            max_results=5,
            search_depth="advanced"
        )
        
        # Evaluate web search results
        chain = self.web_evaluation_prompt | self.llm | JsonOutputParser()
        web_search_evaluation = chain.invoke({
            "query_details": state['query_details'],
            "web_search_results": web_search_results['results']
        })
        
        return {
            "web_search_results": web_search_results['results'],
            "web_search_evaluation": web_search_evaluation
        }

    def evaluate_web_search_node(self, state: EnhancedAgentState) -> Dict[str, Any]:
        """Node for evaluating web search results and deciding next steps"""
        return determine_search_sufficiency(state, "web")

    def generate_final_response_node(self, state: EnhancedAgentState) -> Dict[str, Any]:
        """Node for generating final comprehensive response"""
        recent_conversation = state['conversation_history'][-5:]

        chain = self.final_response_prompt | self.llm
        final_response = chain.invoke({
            "processed_input": state['processed_input']['content'],
            "query_details": state['query_details'],
            "document_search_results": state['document_search_results'],
            "web_search_results": state['web_search_results'],
            "recent_conversation": recent_conversation
        })

        ai_message = AIMessage(
            content=final_response.content,
            additional_kwargs={"timestamp": time.time()}
        )
        state['conversation_history'].append(ai_message)
        
        # Collect references
        references = []
        
        # Add document references
        for doc in state.get('document_search_results', []):
            source = doc.get('source', '')
            page = doc.get('page', '')
            if source and source not in references:
                references.append(f"{source} (Page {page})")
        
        # Add web references
        for result in state.get('web_search_results', []):
            url = result.get('url', '')
            if url and url not in references:
                references.append(url)
        
        return {
            "final_response": final_response.content,
            "references": references,
            "conversation_history": state['conversation_history']
        }
    
    def additional_search_node(self, state: EnhancedAgentState) -> Dict[str, Any]:
        """Node for performing additional searches when needed"""
        # Identify information gaps from evaluations
        doc_eval = state.get('document_search_evaluation', {})
        web_eval = state.get('web_search_evaluation', {})
        
        info_gaps_doc = doc_eval.get('Information Gaps', [])
        info_gaps_web = web_eval.get('Information Gaps', [])
        
        # Combine information gaps
        all_gaps = info_gaps_doc + info_gaps_web
        
        # Use Tavily for specialized search on the gaps
        additional_results = []
        for gap in all_gaps:
            if isinstance(gap, str) and gap.strip():
                try:
                    gap_results = self.tavily_client.search(
                        query=f"{gap} legal information {state['query_details'].get('jurisdiction', '')}",
                        max_results=2,
                        search_depth="advanced"
                    )
                    additional_results.extend(gap_results['results'])
                except Exception as e:
                    print(f"Error in additional search: {e}")
        
        # Combine with existing web search results
        current_web_results = state.get('web_search_results', [])
        combined_results = current_web_results + additional_results
        
        # Remove duplicates by URL
        seen_urls = set()
        unique_results = []
        for result in combined_results:
            url = result.get('url', '')
            if url and url not in seen_urls:
                seen_urls.add(url)
                unique_results.append(result)
        
        return {
            "web_search_results": unique_results[:8],  # Limit to top 8 results
            "need_additional_search": False  # Reset flag
        }
    
    def should_perform_additional_search(self, state: EnhancedAgentState) -> str:
        """Decision node to determine if additional search is needed"""
        if state.get("need_additional_search", False):
            return "additional_search"
        return "generate_response"
    
    def build_workflow(self):
        """Construct the agentic workflow using LangGraph with decision points"""
        workflow = StateGraph(EnhancedAgentState)
        
        # Add all nodes
        workflow.add_node("process_input", self.process_input_node)
        workflow.add_node("understand_query", self.understand_query_node)
        workflow.add_node("document_search", self.document_search_node)
        workflow.add_node("evaluate_doc_search", self.evaluate_doc_search_node)
        workflow.add_node("web_search", self.web_search_node)
        workflow.add_node("evaluate_web_search", self.evaluate_web_search_node)
        workflow.add_node("additional_search", self.additional_search_node)
        workflow.add_node("generate_response", self.generate_final_response_node)
        
        # Define workflow edges with decision points
        workflow.set_entry_point("process_input")
        workflow.add_edge("process_input", "understand_query")
        workflow.add_edge("understand_query", "document_search")
        workflow.add_edge("document_search", "evaluate_doc_search")
        
        # Decision after document search
        workflow.add_conditional_edges(
            "evaluate_doc_search",
            self.should_perform_additional_search,
            {
                "additional_search": "additional_search",
                "generate_response": "web_search"
            }
        )
        
        workflow.add_edge("web_search", "evaluate_web_search")
        
        # Decision after web search
        workflow.add_conditional_edges(
            "evaluate_web_search",
            self.should_perform_additional_search,
            {
                "additional_search": "additional_search",
                "generate_response": "generate_response"
            }
        )
        
        workflow.add_edge("additional_search", "generate_response")
        workflow.set_finish_point("generate_response")
        
        return workflow.compile()
    
    def visualize_workflow(self, graph: StateGraph):
        """Visualize the LangGraph workflow with decision points and save it to a file."""
        
        try:
            img = Image(graph.get_graph().draw_mermaid_png())
            with open("mermaid_graph.png", "wb") as f:
                f.write(img.data)
            print("Image saved as mermaid_graph.png")
        except Exception as e:
            print("Error:", e)

    async def process_query(self, query: Any, input_type: str = "text", conversation_history=None):
        """Async method to process user query with any input type"""
        workflow = self.build_workflow()
        initial_state = {
            "input": query,
            "input_type": input_type,
            "conversation_history": conversation_history or []
        }
        
        result = await workflow.ainvoke(initial_state)
        return result

In [59]:
assistant = LegalAIAssistant()
query = "What are the legal implications of breaking a non-compete agreement in California?"

result = await assistant.process_query(query, "text")

Loaded 1465 documents.
Split into 5892 chunks
Successfully imported 5892 chunks into Weaviate
Error: 'str' object has no attribute 'get_graph'


In [61]:
conversation_history = result.get("conversation_history", [])

follow_up_query = "What if I signed the agreement in Nevada but now work in California?"


follow_up_result = await assistant.process_query(
    follow_up_query, 
    "text",
    conversation_history=conversation_history
)

In [62]:
updated_conversation_history = follow_up_result.get("conversation_history", [])

third_query = "Can my former employer sue me for damages?"
third_result = await assistant.process_query(
    third_query,
    "text",
    conversation_history=updated_conversation_history
)

In [63]:
def display_conversation(conversation_history):
    for message in conversation_history:
        speaker = "User" if message.type == "human" else "Assistant"
        print(f"{speaker}: {message.content}\n")

display_conversation(follow_up_result.get("conversation_history", []))

User: What are the legal implications of breaking a non-compete agreement in California?

Assistant: **Legal Implications of Breaking a Non-Compete Agreement in California**

**Introduction**

A non-compete agreement, also known as a covenant not to compete, is a contractual provision that restricts an individual's ability to engage in a particular profession, trade, or business within a specific geographic area for a certain period of time. Breaking a non-compete agreement can have significant legal implications, including potential lawsuits, damages, and injunctions. This response will provide an overview of the legal implications of breaking a non-compete agreement in California, including what constitutes a breach, available remedies, and the enforceability of non-compete agreements in California.

**What Constitutes a Breach of a Non-Compete Agreement?**

A breach of a non-compete agreement occurs when an individual or entity fails to comply with the agreed-upon restrictions, such