# In [None]:
import openai
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from transformers import pipeline
import argparse

# Checkpoint identifiers for the two models this script relies on.
_EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
_NLI_MODEL_NAME = "roberta-large-mnli"

# Sentence encoder used for both the corpus and the query vectors.
model = SentenceTransformer(_EMBEDDING_MODEL_NAME)
# Text-classification pipeline used to test whether a document entails a sentence.
nli_model = pipeline(task="text-classification", model=_NLI_MODEL_NAME)

def retrieve_legal_docs(query, documents, index, k=3):
    """Retrieve the top-k most relevant legal documents for a query.

    Args:
        query: Natural-language question to embed and search with.
        documents: Sequence of document texts; position ``i`` must correspond
            to vector id ``i`` in ``index``.
        index: Search index exposing ``search(embeddings, k)`` and returning
            ``(distances, ids)`` (a FAISS index in this script).
        k: Maximum number of documents to return.

    Returns:
        A list of at most ``k`` documents, in the order the index returned
        them. Empty when there are no documents or ``k`` is not positive.
    """
    if not documents or k <= 0:
        return []

    # Never ask for more neighbours than there are documents: FAISS pads
    # missing results with id -1, which would otherwise index documents[-1]
    # and silently return the last document as a bogus hit.
    k = min(k, len(documents))

    query_embedding = model.encode([query], convert_to_numpy=True)
    _, indices = index.search(query_embedding, k)

    # Skip padding ids (-1) and any id that has no matching document.
    return [documents[i] for i in indices[0] if 0 <= i < len(documents)]

def generate_legal_response(prompt, retrieved_docs, api_key):
    """Generate a legal AI response using OpenAI GPT-4.

    The retrieved documents are joined with newlines and placed in front of
    the question in a single user message.

    Args:
        prompt: The user's legal question.
        retrieved_docs: Iterable of document strings used as context.
        api_key: OpenAI API key used for the request.

    Returns:
        The text content of the first completion choice.
    """
    context = "\n".join(retrieved_docs)
    full_prompt = f"Use the following legal texts to answer:\n{context}\n\nQuestion: {prompt}\nAnswer:"
    messages = [{"role": "user", "content": full_prompt}]

    # Still set the module-level key, as the original code did, so any caller
    # relying on that side effect keeps working.
    openai.api_key = api_key

    # openai>=1.0 removed `openai.ChatCompletion`; that SDK generation exposes
    # an `OpenAI` client class instead, so use it when it is available.
    if hasattr(openai, "OpenAI"):
        client = openai.OpenAI(api_key=api_key)
        response = client.chat.completions.create(model="gpt-4", messages=messages)
        return response.choices[0].message.content

    # Legacy (pre-1.0) SDK path.
    response = openai.ChatCompletion.create(model="gpt-4", messages=messages)
    return response["choices"][0]["message"]["content"]

def _is_entailed(premise, hypothesis):
    """Return True when the NLI pipeline labels the pair as entailment."""
    # Hand the pair over as text/text_pair so the tokenizer inserts the
    # model's own separator tokens; a literal "[SEP]" inside one string is
    # not the separator of a RoBERTa tokenizer and is read as plain text.
    prediction = nli_model({"text": premise, "text_pair": hypothesis})
    # Depending on the input form / transformers version the pipeline hands
    # back either a single dict or a one-element list; accept both.
    if isinstance(prediction, list):
        prediction = prediction[0]
    return prediction["label"].upper() == "ENTAILMENT"


def check_grounding(response, retrieved_docs):
    """Check if each sentence in the response is grounded in retrieved legal documents.

    The response is split after sentence-ending punctuation (``.``, ``!``,
    ``?``) followed by whitespace, so sentences separated by newlines are
    handled and each sentence keeps its punctuation. A sentence counts as
    grounded as soon as one retrieved document entails it.

    NOTE(review): the splitter is still heuristic; legal abbreviations such
    as "v. " or "U.S. " will be split as sentence ends.

    Args:
        response: Generated answer text (``None`` or blank yields no sentences).
        retrieved_docs: Iterable of document strings used as NLI premises.

    Returns:
        A list of ``(sentence, grounded)`` tuples in response order.
    """
    import re

    stripped = (response or "").strip()
    # Drop empty fragments so a blank response does not produce a phantom
    # ungrounded "sentence" that would skew the grounding score.
    sentences = [s for s in re.split(r"(?<=[.!?])\s+", stripped) if s]

    # any() short-circuits on the first entailing document, like the
    # original loop's break.
    return [
        (sentence, any(_is_entailed(doc, sentence) for doc in retrieved_docs))
        for sentence in sentences
    ]

def compute_grounding_score(results):
    """Return the fraction of sentences flagged as grounded.

    ``results`` is a sequence of ``(sentence, grounded)`` pairs. The score is
    the number of pairs with a truthy flag divided by the number of pairs;
    an empty sequence scores 0.
    """
    supported = len([flag for _, flag in results if flag])
    sentence_total = len(results)
    # Guard against dividing by zero when there is nothing to score.
    if sentence_total <= 0:
        return 0
    return supported / sentence_total

def main():
    """Command-line entry point: retrieve, answer, then score groundedness.

    Reads ``--query`` and ``--api_key`` from the command line, indexes a small
    built-in sample corpus, generates an answer from the retrieved documents,
    and prints the answer, the overall groundedness score, and the
    per-sentence verdicts.
    """
    cli = argparse.ArgumentParser(description="Legal RAG Groundedness Detection")
    cli.add_argument("--query", type=str, required=True, help="Legal question")
    cli.add_argument("--api_key", type=str, required=True, help="OpenAI API Key")
    options = cli.parse_args()

    # Tiny in-memory corpus used for demonstration.
    documents = [
        "Under the Fourth Amendment, searches require a warrant unless exigent circumstances exist.",
        "Contract law requires mutual assent for enforceability.",
        "Habeas corpus provides the right to challenge unlawful detention.",
    ]

    # Embed the corpus and load the vectors into an L2 FAISS index whose
    # dimensionality matches the embedding width.
    doc_embeddings = model.encode(documents, convert_to_numpy=True)
    embedding_dim = doc_embeddings.shape[1]
    index = faiss.IndexFlatL2(embedding_dim)
    index.add(np.array(doc_embeddings))

    # Retrieval followed by answer generation.
    retrieved_docs = retrieve_legal_docs(options.query, documents, index)
    response = generate_legal_response(options.query, retrieved_docs, options.api_key)

    # Sentence-level grounding verdicts and the aggregate score.
    grounding_results = check_grounding(response, retrieved_docs)
    grounding_score = compute_grounding_score(grounding_results)

    print("Generated Response:\n", response)
    print("\nGroundedness Score:", grounding_score)
    for sent, grounded in grounding_results:
        print(f"Sentence: {sent}\nGrounded: {grounded}\n")


if __name__ == "__main__":
    main()
