# Getting Started with Self-RAG for Legal Analysis

This notebook demonstrates the basic usage of the Self-RAG system for legal document analysis.

## Overview

This implementation consists of two main components:
1. **Legal Retrieval Pipeline**: Retrieves relevant legal documents
2. **Self-RAG System**: Generates answers with self-verification

In this notebook, we'll focus on the first component: building and testing the retrieval pipeline.

## Setup

In [None]:
import sys
sys.path.append('..')

from src.retrieval.chunking import DocumentChunker, RecursiveCharacterTextSplitter
from src.retrieval.embedding import EmbeddingModel
from src.retrieval.indexing import VectorIndex
from src.retrieval.retriever import LegalRetriever

## 1. Document Chunking

First, let's test the recursive character text splitter (RCTS), which preserves document structure.

In [None]:
# Sample legal text
sample_text = """
To establish negligence, a plaintiff must prove four essential elements:

1. Duty of Care: The defendant owed a legal duty of care to the plaintiff. This duty arises from the relationship between the parties and the foreseeability of harm.

2. Breach of Duty: The defendant breached that duty through action or inaction. The standard is whether a reasonable person would have acted differently under the same circumstances.

3. Causation: The defendant's breach was the actual and proximate cause of the plaintiff's injury. Both cause-in-fact and legal causation must be established.

4. Damages: The plaintiff suffered actual, compensable damages as a result of the defendant's breach. Damages may include economic losses, pain and suffering, and other harms.

Each element must be proven by a preponderance of the evidence, meaning it is more likely than not that the defendant was negligent.
"""

# Create the splitter
splitter = RecursiveCharacterTextSplitter(
    chunk_size=200,
    chunk_overlap=30
)

# Split text
chunks = splitter.split_text(sample_text)

print(f"Split into {len(chunks)} chunks:\n")
for i, chunk in enumerate(chunks):
    print(f"Chunk {i+1} ({len(chunk)} chars):")
    print(chunk)
    print("-" * 80)
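
To make the idea behind recursive splitting concrete, here is a minimal standalone sketch. It is not the implementation in `src.retrieval.chunking`; the function name `recursive_split` and the separator list are assumptions made for illustration, and chunk overlap is left out for brevity. It tries paragraph breaks first and only falls back to finer separators for pieces that are still too long.

```python
def recursive_split(text, chunk_size=200, separators=("\n\n", "\n", " ")):
    """Split text into chunks of at most chunk_size characters.

    Illustrative sketch only (no overlap): coarse separators are tried
    first, finer ones only for pieces that are still too long.
    """
    text = text.strip()
    if len(text) <= chunk_size:
        return [text] if text else []
    if not separators:
        # No separators left: fall back to a hard cut every chunk_size characters.
        return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

    sep, finer = separators[0], separators[1:]
    chunks, current = [], ""
    for piece in text.split(sep):
        candidate = piece if not current else current + sep + piece
        if len(candidate) <= chunk_size:
            current = candidate  # keep merging small pieces into one chunk
            continue
        if current:
            chunks.append(current)
            current = ""
        if len(piece) <= chunk_size:
            current = piece
        else:
            # The piece alone is too long: split it with the finer separators.
            chunks.extend(recursive_split(piece, chunk_size, finer))
    if current:
        chunks.append(current)
    # Drop whitespace-only chunks and trim the rest.
    return [c.strip() for c in chunks if c.strip()]


demo = "First paragraph about duty.\n\nSecond paragraph about breach. It has two sentences.\n\nThird."
for piece in recursive_split(demo, chunk_size=40):
    print(repr(piece))
```

Short paragraphs stay whole, while the long middle paragraph is broken at word boundaries. A production splitter would typically also carry some overlap between neighbouring chunks, as the `chunk_overlap` parameter above suggests.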

## 2. Embedding Generation

Convert text to dense vector embeddings using sentence transformers.

In [None]:
# Create embedding model (using smaller model for CPU)
print("Loading embedding model...")
embedding_model = EmbeddingModel(
    model_name="sentence-transformers/all-MiniLM-L6-v2",  # Fast, small model (384-dimensional embeddings)
    device="cpu",
    batch_size=8
)

# Encode chunks
print("\nEncoding chunks...")
embeddings = embedding_model.encode(chunks)

print(f"\nEmbedding shape: {embeddings.shape}")
print(f"Embedding dimension: {embedding_model.get_embedding_dim()}")

## 3. Test Similarity Search

Let's see how well embeddings capture semantic similarity.

In [None]:
# Test query
query = "What must be proven to show causation?"
query_embedding = embedding_model.encode(query)

# Calculate similarities
similarities = embedding_model.similarity(
    query_embedding.reshape(1, -1),
    embeddings
)

# Show results
print(f"Query: {query}\n")
print("Most similar chunks:\n")

# Sort by similarity
ranked_indices = similarities[0].argsort()[::-1]

for rank, idx in enumerate(ranked_indices[:3], 1):
    print(f"{rank}. Similarity: {similarities[0][idx]:.4f}")
    print(f"   {chunks[idx][:100]}...\n")
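
For intuition, a ranking like the one above could be computed with a few lines of plain Python. This sketch assumes that `embedding_model.similarity` computes cosine similarity, which this notebook does not confirm; `cosine_similarity` and `top_k` are hypothetical helper names, and the toy vectors stand in for real embeddings.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # convention chosen here for zero vectors
    return dot / (norm_a * norm_b)


def top_k(query_vec, doc_vecs, k=3):
    """Return (index, score) pairs for the k most similar vectors, best first."""
    scored = [(i, cosine_similarity(query_vec, v)) for i, v in enumerate(doc_vecs)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy 2-dimensional "embeddings" for illustration only
toy_docs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
for idx, score in top_k([1.0, 0.0], toy_docs, k=2):
    print(f"doc {idx}: {score:.4f}")
```

Cosine similarity ignores vector length and compares direction only, which is why it is a common choice for sentence embeddings.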

## 4. Complete Retrieval Pipeline

Now let's use the complete retrieval pipeline with FAISS indexing.

In [None]:
# Sample legal documents
documents = [
    {
        "text": """
        To establish negligence, the plaintiff must prove four elements:
        (1) duty of care, (2) breach of that duty, (3) causation, and (4) damages.
        Each element must be proven by a preponderance of the evidence.
        """,
        "source": "negligence_basics.txt",
        "title": "Elements of Negligence"
    },
    {
        "text": """
        A breach of duty occurs when the defendant fails to exercise reasonable care.
        The standard is objective: what would a reasonable person do in similar circumstances?
        The defendant's subjective beliefs or intentions are generally irrelevant.
        """,
        "source": "negligence_basics.txt",
        "title": "Breach of Duty Standard"
    },
    {
        "text": """
        Causation requires both actual cause (cause-in-fact) and proximate cause (legal cause).
        Actual cause is often determined using the but-for test: but for the defendant's conduct,
        would the harm have occurred? Proximate cause limits liability to foreseeable consequences.
        """,
        "source": "causation.txt",
        "title": "Causation in Negligence"
    },
    {
        "text": """
        Damages in negligence cases must be actual and compensable. The plaintiff cannot recover
        for nominal damages alone. Compensatory damages may include medical expenses, lost wages,
        property damage, and pain and suffering. Punitive damages are rarely available.
        """,
        "source": "damages.txt",
        "title": "Types of Damages"
    },
    {
        "text": """
        The duty of care is determined by the relationship between parties and the foreseeability
        of harm. In general, everyone owes a duty to exercise reasonable care to avoid causing
        foreseeable harm to others. Special relationships may create heightened duties.
        """,
        "source": "duty_of_care.txt",
        "title": "Duty of Care Principles"
    }
]

In [None]:
# Create retriever with chunker
chunker_config = {
    'chunk_size': 256,
    'chunk_overlap': 30
}
chunker = DocumentChunker(chunker_config)

retriever = LegalRetriever(
    embedding_model=embedding_model,
    chunker=chunker,
    top_k=3
)

# Index documents
print("Indexing documents...")
retriever.index_documents(documents)
print(f"\nIndexed {retriever.get_num_documents()} document chunks")

In [None]:
# Test retrieval with different queries
test_queries = [
    "What are the elements of negligence?",
    "How is causation established?",
    "What is the standard for breach of duty?",
    "What types of damages can be recovered?"
]

for query in test_queries:
    print(f"\n{'='*80}")
    print(f"Query: {query}")
    print(f"{'='*80}\n")
    
    results = retriever.retrieve(query, top_k=2)
    
    for i, doc in enumerate(results, 1):
        print(f"{i}. Score: {doc['score']:.4f}")
        print(f"   Title: {doc.get('title', 'N/A')}")
        print(f"   Source: {doc.get('source', 'N/A')}")
        print(f"   Text: {doc['text'].strip()[:150]}...\n")

## 5. Save and Load Index

For larger datasets, you'll want to save the index to avoid recomputing embeddings.

In [None]:
# Save index
retriever.save_index("../data/embeddings")
print("Index saved!")

# Load index in a new retriever
new_retriever = LegalRetriever(
    embedding_model=embedding_model,
    top_k=3
)
new_retriever.load_index("../data/embeddings")
print(f"Index loaded: {new_retriever.get_num_documents()} document chunks")

# Test loaded index
results = new_retriever.retrieve("What is causation?", top_k=2)
print(f"\nTest retrieval: {len(results)} results returned")
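
Counting results only shows that the loaded index responds. A slightly stronger check is to compare the results of the same query from both retrievers. The helper below is a sketch: `results_match` is a hypothetical name, and it assumes each result is a dict with `text` and `score` keys, as the printing code earlier in this notebook suggests.

```python
import math


def results_match(results_a, results_b, tol=1e-5):
    """True if both result lists contain the same texts in the same order
    and their scores agree within tol."""
    if len(results_a) != len(results_b):
        return False
    for a, b in zip(results_a, results_b):
        if a["text"] != b["text"]:
            return False
        if not math.isclose(a["score"], b["score"], abs_tol=tol):
            return False
    return True


# In this notebook the check could look like this (not executed here):
# q = "What is causation?"
# results_match(retriever.retrieve(q, top_k=2), new_retriever.retrieve(q, top_k=2))
```

The small tolerance allows for floating-point differences that may appear when scores are recomputed after loading.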

## Next Steps

1. **Prepare Your Data**: Collect or download legal documents (e.g., LegalBench-RAG or case law databases)
2. **Build Full Index**: Index your complete legal corpus
3. **Train Self-RAG Models**: Follow the training guide to train critic and generator models
4. **Evaluate**: Measure retrieval quality (Precision@k, Recall@k)
5. **Integrate with Self-RAG**: Use retriever with the Self-RAG generator for question answering

See `IMPLEMENTATION_GUIDE.md` for detailed instructions on training and using the Self-RAG system.
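
As a starting point for the evaluation step, here is a minimal sketch of Precision@k and Recall@k. The function names and the example ids are illustrative assumptions, not part of this repository. It expects a ranked list of retrieved ids (without duplicates) and a set of ids judged relevant for the query.

```python
def precision_at_k(retrieved, relevant, k):
    """Fraction of the top-k slots filled with relevant ids.

    The denominator is k even if fewer than k results were returned.
    """
    if k <= 0:
        return 0.0
    hits = sum(1 for doc_id in retrieved[:k] if doc_id in relevant)
    return hits / k


def recall_at_k(retrieved, relevant, k):
    """Fraction of all relevant ids that appear in the top-k results."""
    if not relevant:
        return 0.0
    hits = len(set(retrieved[:k]) & set(relevant))
    return hits / len(relevant)


# Hypothetical ranking and relevance labels for one query
retrieved = ["causation.txt", "negligence_basics.txt", "damages.txt"]
relevant = {"causation.txt", "duty_of_care.txt"}

print(f"Precision@2: {precision_at_k(retrieved, relevant, 2):.2f}")
print(f"Recall@2:    {recall_at_k(retrieved, relevant, 2):.2f}")
```

Averaging these values over a set of labeled queries gives a single number per metric that can be tracked as the chunking and embedding settings change.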