# Setup the notebook

In [13]:
# Uncomment to install the notebook's dependencies into the kernel's environment.
# torch is listed too because the import cell below imports it.
# %pip install faiss-cpu numpy transformers torch

In [21]:
# Load metadata and embeddings
# NOTE: the imports and the index load below repeat the later import / "Load FAISS
# index" cells on purpose. This cell comes first in top-to-bottom order, so without
# them a fresh kernel (Restart & Run All) fails with NameError on np / json / index.
import json

import faiss
import numpy as np

# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects from the
# file. Keep it only if this .npy is trusted or really stores an object array.
embeddings = np.load("data/faiss_embeddings.npy", allow_pickle=True)
# NOTE(review): the later load cell reads "data/metadata.json" instead of this file;
# confirm which metadata file actually corresponds to the index.
with open("metadata_new.json", "r") as f:
    metadata = json.load(f)
index = faiss.read_index("data/legal_cases_index.faiss")

# Check alignment: query results are looked up in metadata by index position, so the
# three artifacts must describe the same number of vectors.
assert len(metadata) == embeddings.shape[0], "Mismatch between metadata and embeddings count!"
assert index.ntotal == embeddings.shape[0], "Mismatch between FAISS index and embeddings count!"

# Validate embeddings array shape
print(f"Embeddings shape: {embeddings.shape}")

# Validate metadata
print(f"Metadata entries: {len(metadata)}")

# Validate FAISS index size
print(f"FAISS index size: {index.ntotal}")



Embeddings shape: (3777, 384)
Metadata entries: 3777
FAISS index size: 3777


In [22]:
import faiss
import json
import numpy as np
from transformers import AutoTokenizer, AutoModel
import torch


In [23]:
# Read the prebuilt FAISS index from disk
index = faiss.read_index("data/legal_cases_index.faiss")

# Read the metadata list; query_index looks entries up by FAISS result position
with open("data/metadata.json", "r") as metadata_file:
    metadata = json.load(metadata_file)

# Embedding function

In [24]:
# Tokenizer/model pairs that have already been loaded, keyed by model name.
_EMBEDDING_BACKENDS = {}


def _load_embedding_backend(model_name):
    """Return the (tokenizer, model) pair for ``model_name``, loading it only once."""
    if model_name not in _EMBEDDING_BACKENDS:
        _EMBEDDING_BACKENDS[model_name] = (
            AutoTokenizer.from_pretrained(model_name),
            AutoModel.from_pretrained(model_name),
        )
    return _EMBEDDING_BACKENDS[model_name]


def embed_text(text):
    """Generate embeddings for a given text.

    The token states of ``sentence-transformers/all-MiniLM-L6-v2`` are averaged
    into one vector per input. The average is weighted by the attention mask so
    that padding tokens (present when several texts of different lengths are
    embedded together) do not dilute the result; for a single text this equals
    a plain mean over the tokens.

    Args:
        text: A string, or a list of strings, to embed.

    Returns:
        A 2-D float32 numpy array with one row per input text.
    """
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    # Loading from disk is expensive, so reuse the pair across calls.
    tokenizer, model = _load_embedding_backend(model_name)

    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        token_states = model(**inputs).last_hidden_state
        # Broadcast the mask over the hidden dimension: 1 for real tokens, 0 for padding.
        mask = inputs["attention_mask"].unsqueeze(-1).to(token_states.dtype)
        # clamp guards against division by zero if a row had no unmasked tokens.
        pooled = (token_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
    # FAISS searches expect float32 query vectors.
    return pooled.cpu().numpy().astype("float32", copy=False)


# Query function

In [25]:
def query_index(user_query, index, metadata, k=5):
    """Perform a query on the FAISS index and return top results.

    Args:
        user_query: Text to embed with ``embed_text`` and search for.
        index: Object exposing ``search(query_embedding, k)`` that returns
            ``(distances, indices)``, e.g. a loaded FAISS index.
        metadata: Sequence of dicts with ``"file"`` and ``"text"`` keys, looked
            up by the positions the index returns.
        k: Number of nearest neighbours to request (default 5).

    Returns:
        A list of dicts with keys ``rank``, ``file``, ``text`` and ``distance``,
        in the order returned by the index. Positions that do not map to a
        metadata entry are reported and skipped.
    """
    query_embedding = embed_text(user_query)

    distances, indices = index.search(query_embedding, k)

    # Debugging query results
    print(f"Query embedding shape: {query_embedding.shape}")
    print(f"Distances: {distances}")
    print(f"Indices: {indices}")

    results = []
    for i, idx in enumerate(indices[0]):
        idx = int(idx)
        # FAISS pads missing neighbours with -1. Without the lower bound,
        # metadata[-1] would silently return the last entry as a bogus hit.
        if 0 <= idx < len(metadata):
            results.append({
                "rank": i + 1,
                "file": metadata[idx]["file"],
                "text": metadata[idx]["text"],
                "distance": float(distances[0][i]),
            })
        else:
            print(f"Invalid index: {idx} for metadata size: {len(metadata)}")

    return results


# Define interactive query system

In [26]:
def query_system(save_results=True):
    """Interactive query system with an option to save results.

    Repeatedly shows a menu, reads a query type and a query string from
    standard input, runs ``query_index`` against the notebook-level ``index``
    and ``metadata``, and prints the hits. Typing ``exit`` at either prompt
    ends the loop. Every query type is passed to ``query_index`` as plain text;
    the menu choice only selects the prompt wording and the output file name.

    Args:
        save_results: When true, write each result list to
            ``query_results_<choice>.json`` in the working directory. A later
            query of the same type overwrites the earlier file.
    """
    # Maps each valid menu choice to the prompt used to collect the query text.
    prompts = {
        "1": "Enter case name: ",
        "2": "Enter case abbreviation: ",
        "3": "Enter decision date (YYYY-MM-DD): ",
        "4": "Enter jurisdiction: ",
        "5": "Enter your custom query: ",
    }

    while True:
        print("\nWelcome to the Legal Case Retrieval System!")
        print("Type 'exit' at any point to quit.")
        print("Select a query type:")
        print("1. Search by Name")
        print("2. Search by Abbreviation")
        print("3. Search by Decision Date")
        print("4. Search by Jurisdiction")
        print("5. Custom Legal Query")

        choice = input("Enter choice (1-5): ").strip()
        if choice.lower() == "exit":
            print("Exiting the system. Goodbye!")
            break

        prompt = prompts.get(choice)
        if prompt is None:
            print("Invalid choice. Please try again.")
            continue

        query = input(prompt).strip()
        if query.lower() == "exit":
            print("Exiting the system. Goodbye!")
            break

        # An empty string carries no information to search on, so ask again
        # instead of embedding it and returning arbitrary neighbours.
        if not query:
            print("Empty query. Please try again.")
            continue

        # Perform the query using precomputed embeddings and metadata
        results = query_index(query, index, metadata)

        # Display results
        print("\nQuery Results:")
        for result in results:
            print(f"Rank: {result['rank']}")
            print(f"File: {result['file']}")
            print(f"Text Snippet: {result['text'][:200]}...")
            print(f"Distance: {result['distance']:.4f}")
            print("\n")

        if save_results:
            output_file = f"query_results_{choice}.json"
            with open(output_file, "w") as f:
                json.dump(results, f, indent=4)
            print(f"Results saved to {output_file}")

        print("\nWould you like to perform another query?")


# Run the interactive query system

In [None]:
# Start the interactive loop; type 'exit' at either prompt to stop.
# With the default save_results=True, each result list is also written to a JSON file.
query_system()


Welcome to the Legal Case Retrieval System!
Type 'exit' at any point to quit.
Select a query type:
1. Search by Name
2. Search by Abbreviation
3. Search by Decision Date
4. Search by Jurisdiction
5. Custom Legal Query


Enter choice (1-5):  What about 1901-11-16?


Invalid choice. Please try again.

Welcome to the Legal Case Retrieval System!
Type 'exit' at any point to quit.
Select a query type:
1. Search by Name
2. Search by Abbreviation
3. Search by Decision Date
4. Search by Jurisdiction
5. Custom Legal Query
