In [7]:
import time
import torch
import numpy as np
from transformers import AutoModel, AutoTokenizer
from openai import OpenAI
from typing import List, Tuple
from sentence_transformers import SentenceTransformer, util, CrossEncoder #CrossEncoder: A specialized model for sentence pair classification (used for NLI).

def load_data():
    """Load queries, responses, and context from a dataset."""
    return [
        {
            "query": "When does the duty to preserve documents end?",
            "context": [
                "The duty to preserve evidence arises when a party reasonably anticipates litigation.",
                "The duty generally requires the party to suspend its routine document retention policy.",
                "The duty to preserve evidence ends when the litigation is resolved."
            ],
            "responses": [
                "The duty to preserve evidence ends when the litigation is resolved.",  # Grounded
                "The duty to preserve evidence ends immediately after a motion to dismiss is filed."  # Ungrounded
                

            ]
        },
        {
            "query": "Can a lawyer disclose confidential client information?",
            "context": [
                "Under attorney-client privilege, a lawyer cannot disclose confidential client communications unless the client consents.",
                "However, privilege does not apply if the lawyer reasonably believes disclosure is necessary to prevent imminent harm.",
                "Certain exceptions, such as court orders, may override privilege."
            ],
            "responses": [
                "A lawyer cannot disclose confidential client information unless the client consents, or an exception applies, such as preventing imminent harm.",  # Grounded
                "A lawyer may disclose confidential client information whenever they believe it is in the client’s best interest."  # Ungrounded
            ]
        },
        {
            "query": "What happens if one party breaches a contract?",
            "context": [
                "If a party breaches a contract, the non-breaching party may seek damages, specific performance, or termination of the contract.",
                "Compensatory damages aim to restore the injured party to the position they would have been in had the breach not occurred.",
                "Specific performance is ordered when monetary damages are insufficient."
            ],
            "responses": [
                "If a party breaches a contract, the non-breaching party can seek damages or specific performance, depending on the circumstances.",  # Grounded
                "If a party breaches a contract, they are automatically sentenced to jail."  # Ungrounded
            ]
        },
        {
            "query": "Can an employer fire an employee without cause?",
            "context": [
                "In at-will employment states, an employer can terminate an employee without cause unless there is a contract or statutory protection.",
                "Employees under a collective bargaining agreement or specific employment contract may have additional protections.",
                "Federal laws prohibit termination based on discrimination or retaliation."
            ],
            "responses": [
                "In at-will employment states, an employer can terminate an employee without cause, except in cases of discrimination or contractual protections.",  # Grounded
                "Employers can never fire an employee without cause, even in at-will employment states."  # Ungrounded
            ]
        },
        {
            "query": "Where is Universal Studios?",
            "context": [
                "Universal Studios is in Orlando, Florida.",
                "Employees get free park tickets every quarter.",
                "Federal laws prohibit termination based on discrimination or retaliation."
            ],
            "responses": [
                "Universal Studios is in Los Angeles.",  # Grounded
                "Employees get unlimited free tickets."  # Ungrounded
            ]
        }
    ]

def model_scores_to_class_id(scores):
    #This function converts the output scores from the NLI model into class IDs.
    #NLI models classify text into 3 categories:
        #Index 0 → Contradiction (response contradicts the context)
        #Index 1 → Neutral (response is unrelated to the context)
        #Index 2 → Entailment (response logically follows the context)
    #First, it finds the index with the highest score (argmax()).
    #Then, it maps the original model output to a different ordering (ensuring that 2 still represents entailment).

    model_class_id = scores.argmax(axis=1)
    return np.where(
        model_class_id == 0,
        2,
        np.where(
            model_class_id == 1,
            0,
            1
        )
    )

def compute_cosine_similarity(model, context_sentences: List[str], response_sentences: List[str]) -> float:
    #Measures semantic similarity between the response and the context.
    #Uses Sentence-BERT (SBERT) embeddings to convert sentences into vectors.
    #How It Works
    #Convert context and response into embeddings (numerical vectors).
    #Compute cosine similarity (measures closeness of vectors).
    #Return the average similarity score.
    """Compute the cosine similarity between context and response sentences."""
    context_embeddings = model.encode(context_sentences, convert_to_tensor=True)
    response_embeddings = model.encode(response_sentences, convert_to_tensor=True)
    similarities = util.pytorch_cos_sim(response_embeddings, context_embeddings)
    max_similarities = torch.max(similarities, dim=1)[0]
    return torch.mean(max_similarities).item()

def evaluate_nli(model, tokenizer, context: List[str], response: str) -> float:
    #Uses Natural Language Inference (NLI) to check whether the response logically follows from the context.
    #How It Works
    #Concatenates the context sentences into a single text.
    #Encodes and tokenizes the input.
    #Passes input to the NLI model and gets the output logits.
    #Applies softmax to get probability scores.
    #Uses model_scores_to_class_id() to map the output to a class (contradiction, neutral, or entailment).
    """Evaluate Natural Language Inference (NLI) for factual consistency."""
    context_text = " ".join(context)
    inputs = tokenizer(context_text, response, return_tensors='pt', truncation=True)
    outputs = model(**inputs)
    
    if hasattr(outputs, "logits"):
        logits = outputs.logits
    else:
        raise AttributeError(f"Unexpected model output format: {outputs.keys()}")
    
    scores = torch.softmax(logits, dim=1).cpu().numpy()
    predicted_class = model_scores_to_class_id(scores)[0]
    return predicted_class

def assess_groundedness():
    """Main function to assess groundedness using multiple methods."""
    data = load_data()
    results = []
    
    # Load models
    sim_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    nli_model = CrossEncoder('cross-encoder/nli-deberta-base')
    
    for entry in data:
        query, context, responses = entry['query'], entry['context'], entry['responses']
        
        for response in responses:
            start_time = time.time()
            
            # Similarity-based grounding assessment
            cosine_score = compute_cosine_similarity(sim_model, context, [response])
            
            # Natural Language Inference grounding assessment
            nli_score = nli_model.predict([(context_text, response) for context_text in context])
            nli_score = model_scores_to_class_id(nli_score).mean()
            
            latency = time.time() - start_time
            
            results.append({
                "query": query,
                "response": response,
                "cosine_similarity": cosine_score,
                "nli_score": nli_score,
                "latency": latency,
                "final_groundedness": "Grounded" if (cosine_score > 0.6 and nli_score > 0.5) else "Ungrounded"
            })
    
    return results

def main():
    results = assess_groundedness()
    for res in results:
        print("Query:", res["query"])
        print("Response:", res["response"])
        print("Cosine Similarity:", res["cosine_similarity"])
        print("NLI Score:", res["nli_score"])
        print("Final Assessment:", res["final_groundedness"])
        print("Latency:", res["latency"], "seconds\n")
#Considers both semantic similarity and logical entailment.
        
main()
#Expected Output Values
#The function returns a probability (0 to 1):
#Close to 1 → Strong entailment (Response is well-grounded)
#Around 0.5 → Neutral (Uncertain if grounded)
#Close to 0 → Contradiction (Response is likely ungrounded)

Query: When does the duty to preserve documents end?
Response: The duty to preserve evidence ends when the litigation is resolved.
Cosine Similarity: 1.0
NLI Score: 1.0
Final Assessment: Grounded
Latency: 0.029059171676635742 seconds

Query: When does the duty to preserve documents end?
Response: The duty to preserve evidence ends immediately after a motion to dismiss is filed.
Cosine Similarity: 0.8163374662399292
NLI Score: 1.3333333333333333
Final Assessment: Grounded
Latency: 0.02658224105834961 seconds

Query: Can a lawyer disclose confidential client information?
Response: A lawyer cannot disclose confidential client information unless the client consents, or an exception applies, such as preventing imminent harm.
Cosine Similarity: 0.8755083084106445
NLI Score: 1.0
Final Assessment: Grounded
Latency: 0.0273129940032959 seconds

Query: Can a lawyer disclose confidential client information?
Response: A lawyer may disclose confidential client information whenever they believe it is