In [None]:
# RAG System Implementation
## RAG vs Fine-Tuning: A Comparative Study for Legal QA

This notebook implements the complete Retrieval-Augmented Generation (RAG) system, combining Mistral-7B-Instruct with the FAISS vector database built from Indian legal documents.

**Components:**
- **Retrieval**: FAISS vector database with legal document chunks
- **Generation**: Mistral-7B-Instruct-v0.1 for answering questions
- **Pipeline**: Query → Retrieve → Generate → Response
- **Evaluation**: Performance metrics for comparison with fine-tuning


In [None]:
## 1. Setup and Load Vector Database


In [None]:
import os
import json
import pickle
import torch
import numpy as np
import pandas as pd
from typing import List, Dict, Any
import warnings
warnings.filterwarnings('ignore')

# Transformers and RAG imports
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document

# Evaluation imports
import time
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

print("📦 RAG System Setup Complete!")

# Every later cell needs the FAISS index produced by notebook 1, so fail fast if it is absent.
_vector_db_present = os.path.exists('./vector_db/faiss_legal_db')
if _vector_db_present:
    print("✅ Vector database found, ready to load RAG system")
else:
    print("❌ Vector database not found!")
    print("Please run '1_vector_database_creation.ipynb' first to create the vector database.")
    raise FileNotFoundError("Vector database not found")


In [None]:
## 2. Load Vector Database and Embeddings


In [None]:
# Load the embedding model (same as used for vector database)
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

print(f"🔄 Loading embedding model: {EMBEDDING_MODEL}")
embeddings = HuggingFaceEmbeddings(
    model_name=EMBEDDING_MODEL,
    model_kwargs={'device': 'cpu'},  # Use 'cuda' if GPU available
    encode_kwargs={'normalize_embeddings': True}
)

print(f"✅ Embedding model loaded")

# Load the FAISS vector database
print(f"🔄 Loading FAISS vector database...")
try:
    vectorstore = FAISS.load_local(
        "./vector_db/faiss_legal_db", 
        embeddings,
        allow_dangerous_deserialization=True  # For local files
    )
    print(f"✅ FAISS vector database loaded successfully")
    
    # Test retrieval
    test_query = "contract obligations and legal provisions"
    test_results = vectorstore.similarity_search(test_query, k=3)
    print(f"   📊 Test retrieval: Found {len(test_results)} relevant documents")
    print(f"   📄 Sample result length: {len(test_results[0].page_content)} characters")
    
except Exception as e:
    print(f"❌ Error loading vector database: {e}")
    raise

# Load metadata if available
try:
    with open('./processed_docs/rag_metadata.json', 'r') as f:
        rag_metadata = json.load(f)
    print(f"📋 Metadata loaded:")
    print(f"   Total chunks: {rag_metadata['chunking_strategy']['total_chunks']:,}")
    print(f"   Embedding dimension: {rag_metadata['embedding_info']['dimension']}")
    print(f"   Legal coverage: {rag_metadata['legal_content_analysis']['section_percentage']:.1f}% with sections")
except:
    print("⚠️  Metadata file not found, continuing without detailed stats")


In [None]:
## 3. Load Mistral Model for Generation


In [None]:
# Model configuration
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.1"

# 4-bit NF4 quantization for memory efficiency. It typically needs bitsandbytes and a
# CUDA GPU; if the quantized load fails, an unquantized load is attempted below.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

print(f"🔄 Loading Mistral model: {MODEL_NAME}")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    # No pad token defined: reuse EOS so padding-aware code paths work
    tokenizer.pad_token = tokenizer.eos_token

print(f"✅ Tokenizer loaded")

# Load model with quantization. Mistral is supported natively by transformers, so
# trust_remote_code is unnecessary (and would let code from the hub repo execute).
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=quantization_config,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    print(f"✅ Mistral model loaded with 4-bit quantization")

    if torch.cuda.is_available():
        print(f"   🔧 GPU: {torch.cuda.get_device_name()}")
        print(f"   💾 Model device: {next(model.parameters()).device}")
    else:
        print(f"   💻 Running on CPU")

except Exception as e:
    print(f"❌ Error loading model: {e}")
    print("Trying without quantization...")

    # Fallback: load without quantization
    if torch.cuda.is_available():
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16,
            device_map="auto"
        )
    else:
        # float16 kernels are incomplete/slow on CPU; bfloat16 has the same memory footprint
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.bfloat16,
            device_map="cpu"
        )
    print(f"✅ Model loaded without quantization")

# Test model generation
test_prompt = "<s>[INST] What is a legal contract? [/INST]"
# The prompt already contains the BOS marker "<s>", so don't let the tokenizer prepend a second one
test_inputs = tokenizer(test_prompt, return_tensors="pt", add_special_tokens=False)
# Always place inputs on the model's (first) device, whatever device_map chose
test_inputs = {k: v.to(model.device) for k, v in test_inputs.items()}

with torch.no_grad():
    test_output = model.generate(
        **test_inputs,
        max_new_tokens=50,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id
    )

# Decode only the newly generated tokens. Slicing the decoded text by len(test_prompt)
# is unreliable because skip_special_tokens drops "<s>" and shifts character offsets.
prompt_token_count = test_inputs["input_ids"].shape[1]
test_response = tokenizer.decode(
    test_output[0][prompt_token_count:], skip_special_tokens=True
).strip()
print(f"\n🧪 Model test successful:")
print(f"Response preview: {test_response[:50]}...")


In [None]:
## 4. Implement RAG Pipeline


In [None]:
class LegalRAGSystem:
    """Complete RAG system for legal question answering.

    Pipeline: retrieve the top-k chunks from ``vectorstore``, pack them into a
    character-bounded context string, then prompt the causal LM (Mistral
    ``[INST] ... [/INST]`` format) to answer from that context.
    """

    def __init__(self, vectorstore, model, tokenizer, k_retrieve=5):
        """
        Args:
            vectorstore: Object exposing ``similarity_search_with_score(query, k=...)``
                (e.g. a LangChain FAISS store).
            model: Causal LM exposing ``generate`` and ``device``.
            tokenizer: Tokenizer matching ``model``.
            k_retrieve: Default number of chunks to retrieve per query.
        """
        self.vectorstore = vectorstore
        self.model = model
        self.tokenizer = tokenizer
        self.k_retrieve = k_retrieve

    def retrieve_documents(self, query: str, k: int = None) -> List[Document]:
        """Return up to ``k`` documents for ``query``, most similar first.

        A falsy ``k`` (None or 0) falls back to ``self.k_retrieve``.
        """
        k = k or self.k_retrieve

        scored_docs = self.vectorstore.similarity_search_with_score(query, k=k)

        # The score is a distance: lower score = more similar
        scored_docs = sorted(scored_docs, key=lambda pair: pair[1])
        return [doc for doc, _score in scored_docs]

    @staticmethod
    def _doc_header(doc) -> str:
        """Build the bracketed provenance tag that prefixes each chunk in the context."""
        header = f"[Doc {doc.metadata.get('source_doc_id', 'N/A')}]"
        if doc.metadata.get('has_sections'):
            header += " [Contains Legal Sections]"
        if doc.metadata.get('has_court_names'):
            header += " [Court Document]"
        return header

    def create_context(self, documents: List[Document], max_context_length: int = 1500) -> str:
        """Join retrieved documents into one context string of at most ``max_context_length`` chars.

        Documents are added in order, each prefixed with a provenance tag and separated
        by a blank line. The first document that does not fit is truncated (with a
        trailing "...") if more than 100 characters of it fit; otherwise it is dropped.
        Later documents are never added. Separators and the "..." marker count toward the limit.
        """
        separator = "\n\n"
        ellipsis = "..."
        min_partial_chars = 100  # don't bother adding tiny truncated fragments

        context_parts = []
        current_length = 0

        for doc in documents:
            doc_text = f"{self._doc_header(doc)} {doc.page_content}"

            separator_length = len(separator) if context_parts else 0
            remaining_space = max_context_length - current_length - separator_length

            if len(doc_text) > remaining_space:
                # Truncate the last document so the result, "..." included, stays within the limit
                partial_chars = remaining_space - len(ellipsis)
                if partial_chars > min_partial_chars:
                    context_parts.append(doc_text[:partial_chars] + ellipsis)
                break

            context_parts.append(doc_text)
            current_length += separator_length + len(doc_text)

        return separator.join(context_parts)

    def generate_response(self, query: str, context: str, max_new_tokens: int = 512) -> str:
        """Generate an answer to ``query`` grounded in ``context`` using the model.

        Generation samples (temperature 0.7, top-p 0.9), so output is stochastic
        unless the torch RNG is seeded by the caller.

        Returns:
            Only the text the model generated, with the prompt (including the
            assistant prefill sentence) excluded.
        """

        # Create the prompt with retrieved context
        prompt = f"""<s>[INST] You are a legal AI assistant specializing in Indian law. Use the provided legal documents to answer the question accurately and comprehensively. Base your answer primarily on the given context.

Legal Documents:
{context}

Question: {query}

Please provide a detailed answer based on the legal documents provided above. [/INST]

Based on the legal documents provided, I can answer your question as follows:

"""

        # The prompt already begins with the literal BOS "<s>", so don't add a second one.
        # No right-side truncation: it would silently cut off the question and [/INST];
        # prompt size is bounded by create_context's max_context_length instead.
        inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False)

        # Move to same device as model (works for CPU and device_map placements alike)
        inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
        prompt_token_count = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=1.1,
            )

        # Decode only the newly generated tokens; string-splitting on "[/INST]" would
        # keep the prefill sentence and break if the model itself emits "[/INST]".
        generated_ids = outputs[0][prompt_token_count:]
        return self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    def answer_question(self, query: str, k_retrieve: int = None, max_context_length: int = 1500, max_new_tokens: int = 512) -> Dict[str, Any]:
        """Complete RAG pipeline: retrieve + build context + generate.

        Returns:
            Dict with the query, retrieved documents, context, response, simple
            size metrics, and wall-clock ``processing_time`` in seconds.
        """
        start_time = time.perf_counter()  # monotonic clock for interval timing

        # Step 1: Retrieve relevant documents
        retrieved_docs = self.retrieve_documents(query, k_retrieve)

        # Step 2: Create context from retrieved documents
        context = self.create_context(retrieved_docs, max_context_length)

        # Step 3: Generate response
        response = self.generate_response(query, context, max_new_tokens)

        elapsed = time.perf_counter() - start_time

        return {
            'query': query,
            'retrieved_docs': retrieved_docs,
            'context': context,
            'response': response,
            'retrieval_count': len(retrieved_docs),
            'context_length': len(context),
            'response_length': len(response),
            'processing_time': elapsed
        }

# Wire the vector store, generator model and tokenizer into a single pipeline object
rag_system = LegalRAGSystem(vectorstore=vectorstore, model=model, tokenizer=tokenizer, k_retrieve=5)

print("\n".join([
    "🚀 Legal RAG System Initialized Successfully!",
    f"   📊 Retrieval: Top-{rag_system.k_retrieve} documents",
    f"   🤖 Generation: {MODEL_NAME}",
    f"   📝 Ready for legal question answering",
]))


In [None]:
## 5. Test RAG System with Legal Questions


In [None]:
# Generation samples (do_sample=True), so fix the torch RNG to make this run reproducible
SEED = 42
torch.manual_seed(SEED)

# Test questions for the RAG system
test_questions = [
    "What are the legal obligations of contractors in equipment agreements?",
    "How does the Bihar Sales Tax Act apply to machinery sales?",
    "What is the court's decision regarding contract disputes?",
    "What are the payment terms for equipment leasing agreements?",
    "What legal provisions govern the ownership of machinery?",
    "How are legal disputes between corporations and contractors resolved?",
    "What are the consequences of breaching equipment lease agreements?",
    "What role do consulting engineers play in legal agreements?"
]

print("🧪 Testing RAG System with Legal Questions")
print("=" * 80)

rag_results = []
failed_questions = []  # (question, error repr) pairs, reported after the loop

for i, question in enumerate(test_questions):
    print(f"\n📝 Question {i+1}: {question}")
    print("-" * 60)

    try:
        # Get RAG response
        result = rag_system.answer_question(
            query=question,
            k_retrieve=5,
            max_context_length=1200,
            max_new_tokens=300
        )
    except Exception as e:
        # Best-effort evaluation: record the failure and continue with the remaining questions
        print(f"❌ Error processing question: {e}")
        failed_questions.append((question, repr(e)))
        print("-" * 60)
        continue

    rag_results.append(result)

    # Display results
    print(f"🔍 Retrieved {result['retrieval_count']} documents")
    print(f"📄 Context length: {result['context_length']} characters")
    print(f"⏱️  Processing time: {result['processing_time']:.2f} seconds")

    print(f"\n🤖 RAG Response:")
    response_text = result['response']
    print(response_text[:400] + "..." if len(response_text) > 400 else response_text)

    print(f"\n📚 Retrieved Document Sources:")
    for j, doc in enumerate(result['retrieved_docs'][:3]):
        doc_id = doc.metadata.get('source_doc_id', 'N/A')
        has_sections = "✓" if doc.metadata.get('has_sections') else "✗"
        has_courts = "✓" if doc.metadata.get('has_court_names') else "✗"
        print(f"   {j+1}. Doc {doc_id} | Sections: {has_sections} | Courts: {has_courts}")

    print("-" * 60)

print(f"\n✅ RAG System Testing Completed!")
print(f"📊 Successfully processed {len(rag_results)}/{len(test_questions)} questions")
if failed_questions:
    print(f"⚠️  {len(failed_questions)} question(s) failed:")
    for failed_question, error in failed_questions:
        print(f"   - {failed_question} -> {error}")


In [None]:
## 6. Analyze RAG Performance


In [None]:
# Analyze RAG system performance (latency and size metrics only; answer quality is not scored here)

def _plot_metric_hist(ax, values, color, title, xlabel, mean_label):
    """Draw a histogram of `values` on `ax` with a dashed red line at the mean."""
    ax.hist(values, bins=10, alpha=0.7, color=color, edgecolor='black')
    ax.axvline(np.mean(values), color='red', linestyle='--', label=mean_label)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Frequency')
    ax.legend()


if rag_results:
    # Extract performance metrics
    processing_times = [r['processing_time'] for r in rag_results]
    context_lengths = [r['context_length'] for r in rag_results]
    response_lengths = [r['response_length'] for r in rag_results]
    retrieval_counts = [r['retrieval_count'] for r in rag_results]

    # Calculate statistics
    perf_stats = {
        'avg_processing_time': np.mean(processing_times),
        'std_processing_time': np.std(processing_times),
        'avg_context_length': np.mean(context_lengths),
        'avg_response_length': np.mean(response_lengths),
        'avg_retrieval_count': np.mean(retrieval_counts),
        'total_questions': len(rag_results)
    }

    print("📊 RAG Performance Analysis:")
    print("=" * 50)
    print(f"Questions processed: {perf_stats['total_questions']}")
    print(f"Average processing time: {perf_stats['avg_processing_time']:.2f} ± {perf_stats['std_processing_time']:.2f} seconds")
    print(f"Average context length: {perf_stats['avg_context_length']:.0f} characters")
    print(f"Average response length: {perf_stats['avg_response_length']:.0f} characters")
    print(f"Average documents retrieved: {perf_stats['avg_retrieval_count']:.1f}")

    # Visualize performance metrics
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('RAG System Performance Analysis', fontsize=16, fontweight='bold')

    _plot_metric_hist(axes[0, 0], processing_times, 'lightblue',
                      'Processing Time Distribution', 'Time (seconds)',
                      f'Mean: {np.mean(processing_times):.2f}s')
    _plot_metric_hist(axes[0, 1], context_lengths, 'lightgreen',
                      'Context Length Distribution', 'Characters',
                      f'Mean: {np.mean(context_lengths):.0f}')
    _plot_metric_hist(axes[1, 0], response_lengths, 'lightcoral',
                      'Response Length Distribution', 'Characters',
                      f'Mean: {np.mean(response_lengths):.0f}')

    # Processing time vs context length
    axes[1, 1].scatter(context_lengths, processing_times, alpha=0.7, color='purple')
    axes[1, 1].set_title('Processing Time vs Context Length')
    axes[1, 1].set_xlabel('Context Length (characters)')
    axes[1, 1].set_ylabel('Processing Time (seconds)')

    # Pearson r is undefined for fewer than 2 points or a constant series
    # (np.corrcoef would return nan and emit a RuntimeWarning)
    if len(rag_results) > 1 and np.std(context_lengths) > 0 and np.std(processing_times) > 0:
        correlation = np.corrcoef(context_lengths, processing_times)[0, 1]
        correlation_label = f'Correlation: {correlation:.3f}'
    else:
        correlation = float('nan')
        correlation_label = 'Correlation: n/a'
    axes[1, 1].text(0.05, 0.95, correlation_label,
                    transform=axes[1, 1].transAxes, fontsize=10,
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

    plt.tight_layout()
    plt.show()

    # Analyze retrieval quality (metadata coverage of the retrieved chunks)
    print(f"\n📚 Retrieval Quality Analysis:")
    all_retrieved_docs = [doc for result in rag_results for doc in result['retrieved_docs']]
    total_retrieved_docs = len(all_retrieved_docs)
    docs_with_sections = sum(1 for doc in all_retrieved_docs if doc.metadata.get('has_sections'))
    docs_with_courts = sum(1 for doc in all_retrieved_docs if doc.metadata.get('has_court_names'))

    print(f"   Total documents retrieved: {total_retrieved_docs}")
    if total_retrieved_docs:
        print(f"   Documents with legal sections: {docs_with_sections} ({100*docs_with_sections/total_retrieved_docs:.1f}%)")
        print(f"   Documents with court references: {docs_with_courts} ({100*docs_with_courts/total_retrieved_docs:.1f}%)")
    else:
        print("   ⚠️  No documents were retrieved; coverage percentages are undefined")
    # Relevance is not measured here, so report that rather than a hard-coded verdict
    print(f"   Relevance: not scored (no ground-truth relevance labels available)")

else:
    print("⚠️  No RAG results to analyze")


In [None]:
## 7. Save RAG Results for Comparison
