In [4]:
import os
import json
import random
from collections import defaultdict
from typing import List, Tuple
from pydantic import BaseModel, computed_field

from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS


dataset_name = "contractnli"
vectorstore_path = "./vectorstore/faiss_store_gte_base"
test_file = f"../data/benchmarks/{dataset_name}.json"
result_file = f"../data/results/qa_results.json"
embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-base") 

# sentence-transformers/all-MiniLM-L6-v2
# Linq-AI-Research/Linq-Embed-Mistral
# thenlper/gte-base

# Build Vector Store

In [4]:
def load_documents_with_spans(directory: str, chunk_size: int = 1000, chunk_overlap: int = 0):
    """
    Loads .txt files from a directory, splits each document's text into chunks using
    RecursiveCharacterTextSplitter, computes the span (start, end) for each chunk, and
    returns a list of Document objects with metadata (including filename, source, and span).
    """
    documents = []
    # Initialize the splitter with the desired separators and parameters.
    splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", "!", "?", ".", ":", ";", ",", " ", ""],
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        is_separator_regex=False,
        strip_whitespace=False,
    )
    
    # Process each .txt file in the directory.
    for filename in os.listdir(directory):
        if filename.endswith(".txt"):
            filepath = os.path.join(directory, filename)
            with open(filepath, "r", encoding="utf-8") as f:
                text = f.read()
            
            # Split text into chunks.
            text_splits = splitter.split_text(text)
            
            # Verify that the chunks concatenate to the original text.
            assert "".join(text_splits) == text, "Concatenated splits do not match the original text."
            
            # Compute spans and create Document objects.
            prev_index = 0
            for i, chunk_text in enumerate(text_splits):
                span = (prev_index, prev_index + len(chunk_text))
                prev_index += len(chunk_text)
                doc = Document(
                    page_content=chunk_text,
                    metadata={
                        "filename": filename,
                        "filepath": f"{dataset_name}/{filename}",
                        "span": span,  # Stores the (start, end) positions of the chunk.
                        "id": f"{filename}_chunk_{i}"
                    }
                )
                documents.append(doc)
    return documents


# Update this to the folder where your ContractNLI .txt files reside.
directory_path = f"./data/corpus/{dataset_name}"

# Load the documents, splitting each into chunks with span metadata.
documents = load_documents_with_spans(directory_path, chunk_size=500, chunk_overlap=0)
print(f"Loaded {len(documents)} document chunks with spans.")

# Build the FAISS vector store using the list of Document objects.
vectorstore = FAISS.from_documents(documents, embeddings)

# Save the FAISS vector store locally for later retrieval.
vectorstore.save_local(vectorstore_path)
print(f"FAISS vector store saved locally at {vectorstore_path}.")

Loaded 3307 document chunks with spans.
FAISS vector store saved locally at ./vectorstore/faiss_store_gte_base.


In [3]:
# import shutil

# # Check if the directory exists
# if os.path.exists(vectorstore_path):
#     shutil.rmtree(vectorstore_path)
#     print(f"Deleted the FAISS vector store at: {vectorstore_path}")
# else:
#     print(f"No FAISS vector store found at: {vectorstore_path}")


Deleted the FAISS vector store at: ./vectorstore/faiss_store_gte_base


# Evaluation

In [8]:
#############################
# Define Data Models
#############################

class QASnippet(BaseModel):
    file_path: str
    span: Tuple[int, int]
    answer: str

class QAGroundTruth(BaseModel):
    query: str
    snippets: List[QASnippet]

class RetrievedSnippet(BaseModel):
    file_path: str
    span: Tuple[int, int]
    text: str      # Retrieved text content from the FAISS vectorstore
    score: float   # Relevance score returned by similarity search

class QAResult(BaseModel):
    qa_gt: QAGroundTruth
    retrieved_snippets: List[RetrievedSnippet]

    @computed_field
    @property
    def precision(self) -> float:
        total_retrieved_len = 0
        relevant_retrieved_len = 0
        for snippet in self.retrieved_snippets:
            total_retrieved_len += snippet.span[1] - snippet.span[0]
            # Compare with each ground-truth snippet (they are guaranteed not to overlap)
            for gt_snippet in self.qa_gt.snippets:
                if snippet.file_path == gt_snippet.file_path:
                    common_min = max(snippet.span[0], gt_snippet.span[0])
                    common_max = min(snippet.span[1], gt_snippet.span[1])
                    if common_max > common_min:
                        relevant_retrieved_len += common_max - common_min
        if total_retrieved_len == 0:
            return 0
        return relevant_retrieved_len / total_retrieved_len

    @computed_field
    @property
    def recall(self) -> float:
        total_relevant_len = 0
        relevant_retrieved_len = 0
        for gt_snippet in self.qa_gt.snippets:
            total_relevant_len += gt_snippet.span[1] - gt_snippet.span[0]
            for snippet in self.retrieved_snippets:
                if snippet.file_path == gt_snippet.file_path:
                    common_min = max(snippet.span[0], gt_snippet.span[0])
                    common_max = min(snippet.span[1], gt_snippet.span[1])
                    if common_max > common_min:
                        relevant_retrieved_len += common_max - common_min
        if total_relevant_len == 0:
            return 0
        return relevant_retrieved_len / total_relevant_len

#############################
# Helper Functions
#############################

def load_groundtruth(json_file_path: str) -> List[QAGroundTruth]:
    """
    Loads the QA ground-truth data from a JSON file.
    Expected JSON format:
    {
        "tests": [
            {
                "query": "Your query...",
                "snippets": [
                    {
                        "file_path": "path/to/file.txt",
                        "span": [start, end],
                        "answer": "The answer text..."
                    },
                    ...
                ]
            },
            ...
        ]
    }
    """
    with open(json_file_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    groundtruth_tests = []
    for test in data.get("tests", []):
        snippets = [QASnippet(**snippet) for snippet in test["snippets"]]
        groundtruth_tests.append(QAGroundTruth(query=test["query"], snippets=snippets))
    return groundtruth_tests

def perform_retrieval(vectorstore: FAISS, query: str, k: int = 5) -> List[RetrievedSnippet]:
    """
    Uses the FAISS vector store to perform a similarity search on the given query using
    similarity_search_with_relevance_score. Converts the returned Document objects into 
    RetrievedSnippet instances using the metadata, and also stores the relevance score.
    """
    # Retrieve a list of tuples: (Document, relevance_score)
    docs_and_scores: List[Tuple[Document, float]] = vectorstore.similarity_search_with_relevance_scores(query, k=k)
    retrieved = []
    for doc, score in docs_and_scores:
        # Retrieve file path and span from metadata.
        file_path = doc.metadata.get("filepath")
        span = doc.metadata.get("span", (0, len(doc.page_content)))
        retrieved.append(RetrievedSnippet(file_path=file_path, span=span, text=doc.page_content, score=score))
    return retrieved

#############################
# Main Execution
#############################

# 1. Load ground-truth data.
groundtruth_tests = load_groundtruth(test_file)

# 2. Load the FAISS vector store that was previously created.
vectorstore = FAISS.load_local(vectorstore_path, embeddings, allow_dangerous_deserialization=True)

# 3. Evaluate retrieval performance for different k values.
k_values = [1, 3, 5, 10]
all_results = []

for gt in groundtruth_tests:
    for k in k_values:
        retrieved_snippets = perform_retrieval(vectorstore, gt.query, k=k)
        qa_result = QAResult(qa_gt=gt, retrieved_snippets=retrieved_snippets)
        # Create a dictionary of results for this query and k.
        result_dict = {
            "query": gt.query,
            "k": k,
            "precision": qa_result.precision,
            "recall": qa_result.recall,
            "ground_truth": [gt_snippet.dict() for gt_snippet in gt.snippets],
            "retrieved": [snippet.dict() for snippet in retrieved_snippets]
        }
        all_results.append(result_dict)

# 4. Save the results as JSON.
with open(result_file, "w", encoding="utf-8") as f:
    json.dump(all_results, f, indent=2)

print(f"QA results saved to {result_file}.")

/var/folders/hk/j9r7jggx4dxgt8gmzj_c2z080000gn/T/ipykernel_7045/444561711.py:132: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/
  "ground_truth": [gt_snippet.dict() for gt_snippet in gt.snippets],
/var/folders/hk/j9r7jggx4dxgt8gmzj_c2z080000gn/T/ipykernel_7045/444561711.py:133: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.10/migration/
  "retrieved": [snippet.dict() for snippet in retrieved_snippets]


QA results saved to ./data/qa_results.json.


In [2]:
with open(result_file, "r", encoding="utf-8") as f:
    results = json.load(f)

# Dictionary to collect precision and recall values per K.
# The keys will be the K value and the values a list of (precision, recall) tuples.
metrics_by_k = defaultdict(list)

for item in results:
    k = item.get("k")
    precision = item.get("precision", 0)
    recall = item.get("recall", 0)
    metrics_by_k[k].append((precision, recall))

# Compute the average precision and recall for each K.
avg_metrics = {}
for k, metrics in metrics_by_k.items():
    if metrics:
        total_precision = sum(m[0] for m in metrics)
        total_recall = sum(m[1] for m in metrics)
        count = len(metrics)
        avg_precision = total_precision / count
        avg_recall = total_recall / count
    else:
        avg_precision = 0
        avg_recall = 0
    avg_metrics[k] = {"avg_precision": avg_precision, "avg_recall": avg_recall}

# Print the results.
print("Average Precision and Recall for each K:")
for k in sorted(avg_metrics.keys()):
    metrics = avg_metrics[k]
    print(f"K = {k}: Average Precision = {metrics['avg_precision']:.4f}, Average Recall = {metrics['avg_recall']:.4f}")

Average Precision and Recall for each K:
K = 1: Average Precision = 0.0299, Average Recall = 0.0355
K = 3: Average Precision = 0.0213, Average Recall = 0.0565
K = 5: Average Precision = 0.0191, Average Recall = 0.0693
K = 10: Average Precision = 0.0174, Average Recall = 0.0989


In [14]:
qidx = random_number = random.randint(1, len(results))
results[qidx]

{'query': "Consider EFCA's Non-Disclosure Agreement; Does the document permit the Receiving Party to create a copy of some Confidential Information under certain circumstances?",
 'k': 3,
 'precision': 0.20080321285140562,
 'recall': 1.0,
 'ground_truth': [{'file_path': 'contractnli/EFCAConfidentialityAgreement.txt',
   'span': [2459, 2609],
   'answer': 'Copies or reproductions shall not be made except to the extent reasonably necessary and all copies made shall be the property of the disclosing party.'}],
 'retrieved': [{'file_path': 'contractnli/EFCAConfidentialityAgreement.txt',
   'span': [2280, 2609],
   'text': '. EFCA shall ensure that disclosure of such Confidential Information is restricted to those employees or directors of EFCA and EFCA’s principals having the need to know the same. Copies or reproductions shall not be made except to the extent reasonably necessary and all copies made shall be the property of the disclosing party.',
   'score': 0.8845122402835409},
  {'file

In [15]:
qidx = random_number = random.randint(1, len(results))
results[qidx]

{'query': 'Consider the Non-Disclosure Agreement between IGC and LSE; Does the document allow the Receiving Party to share some Confidential Information with third parties, including consultants, agents, and professional advisors?',
 'k': 5,
 'precision': 0.0,
 'recall': 0.0,
 'ground_truth': [{'file_path': 'contractnli/IGC-Non-Disclosure-Agreement-LSE-Sample.txt',
   'span': [4634, 4736],
   'answer': 'Representative means employees, agents, officers, advisers and other representatives of the Recipient.'},
  {'file_path': 'contractnli/IGC-Non-Disclosure-Agreement-LSE-Sample.txt',
   'span': [6909, 7094],
   'answer': "The Recipient may disclose the Disclosing Party's Confidential Information to those of its Representatives who need to know this Confidential Information for the Purpose, provided that:"}],
 'retrieved': [{'file_path': 'contractnli/tpi-non-disclosure-agreement_1.txt',
   'span': [0, 24],
   'text': 'NON-DISCLOSURE AGREEMENT',
   'score': 0.8735252839098877},
  {'file_pat

In [3]:
qidx = random_number = random.randint(1, len(results))
results[qidx]

{'query': "Consider Grindrod SA's Non-Disclosure Agreement; Does the document allow the Receiving Party to independently develop information that is similar to the Confidential Information?",
 'k': 10,
 'precision': 0.16828675577156743,
 'recall': 1.0,
 'ground_truth': [{'file_path': 'contractnli/Grindrod%20SA%20Confidentiality%20and%20Non-Disclosure%20Undertaking.txt',
   'span': [1942, 2219],
   'answer': 'Confidential Information also excludes information in the public domain for a reason, other than a breach of this Confidentiality and Non-Disclosure Undertaking, with any party, or independently developed by the Vendor without reference to information provided by, Grindrod SA;'}],
 'retrieved': [{'file_path': 'contractnli/Grindrod%20SA%20Confidentiality%20and%20Non-Disclosure%20Undertaking.txt',
   'span': [0, 258],
   'text': 'CONFIDENTIALITY AND NON-DISCLOSURE UNDERTAKING\nConfidentiality and Non-Disclosure Undertaking Between Grindrod South Africa (Proprietary) Limited and _____