# Method 2: RAG (Retrieval-Augmented Generation) for NER Extraction

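This notebook demonstrates named entity recognition (NER) with RAG: similar examples are retrieved from a FAISS vector index built over the training corpus and used to augment the LLM's generation.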

## Overview
- **Approach**: Retrieve similar examples from corpus to augment generation
- **Components**: 
  - Embedding Model: BAAI/bge-small-en-v1.5
  - Vector Store: FAISS
  - LLM: Meta-Llama-3.1-8B-Instruct
- **Advantages**: 
  - Better context understanding
  - Reduced hallucination
  - Draws on similar labeled examples retrieved from the corpus
- **Disadvantages**:
  - Requires corpus preparation
  - Slower inference (retrieval overhead)
  - Depends on embedding quality

## 1. Setup and Imports

In [None]:
import sys
sys.path.append('..')

from src.config import NERConfig, PROCESSED_DATA_DIR, RESULTS_DIR, MODELS_DIR
from src.data_loader import NERDataLoader
from src.rag_pipeline import RAGNERExtractor
from src.evaluation import NEREvaluator
from src.benchmark import NERBenchmark

import json
from pathlib import Path

## 2. Load Configuration

In [None]:
# Settings for this run, gathered in one mapping and handed to NERConfig as keywords.
rag_settings = dict(
    model_name="meta-llama/Meta-Llama-3.1-8B-Instruct",
    embedding_model="BAAI/bge-small-en-v1.5",
    temperature=0.1,
    top_k_retrieval=3,  # Retrieve top 3 similar examples
    max_length=2048,
)
config = NERConfig(**rag_settings)

# Echo the values back from the config object (not from the dict above), so the
# summary shows what NERConfig actually holds, including its entity types.
print(
    "Configuration:",
    f"  LLM: {config.model_name}",
    f"  Embedding Model: {config.embedding_model}",
    f"  Top-K Retrieval: {config.top_k_retrieval}",
    f"  Entity types: {config.entity_types}",
    sep="\n",
)

## 3. Load Dataset

In [None]:
# Read the train / validation / test splits from the processed-data directory,
# in that order, with the same loader call for each file.
train_dataset, val_dataset, test_dataset = (
    NERDataLoader.load_json_dataset(PROCESSED_DATA_DIR / split_file)
    for split_file in ("train.json", "validation.json", "test.json")
)

# Report how many items each split contains.
print(
    f"Training set size: {len(train_dataset)}",
    f"Validation set size: {len(val_dataset)}",
    f"Test set size: {len(test_dataset)}",
    sep="\n",
)

## 4. Build RAG Index

In [None]:
# Create the RAG extractor from the shared config, with the training split as its corpus.
# NOTE(review): presumably the constructor embeds the corpus and builds the index here — confirm in src.rag_pipeline.
extractor = RAGNERExtractor(corpus=train_dataset, config=config)

corpus_size = len(train_dataset)
print("\nRAG pipeline initialized!", f"Corpus size: {corpus_size} documents", sep="\n")

## 5. Save FAISS Index (Optional)

In [None]:
# Write the extractor's index to a file under MODELS_DIR so it can be reused.
index_path = MODELS_DIR / "rag_index.faiss"
extractor.save_index(index_path)

save_message = f"Index saved to {index_path}"
print(save_message)

# Reloading in a later session:
#   extractor.load_index(index_path)

## 6. Test Retrieval on Examples

In [None]:
# Sanity-check retrieval: use the first validation text as the query and list
# the corpus examples the extractor returns for it.
sample_text = val_dataset[0]['text']

print("Query Text:")
print(sample_text[:300] + "...\n")

# Use the configured top-k rather than a separate hard-coded 3, so this preview
# follows whatever value was set in the configuration cell.
retrieved = extractor.retrieve(sample_text, top_k=config.top_k_retrieval)

print(f"\nRetrieved {len(retrieved)} similar examples:\n")
for i, doc in enumerate(retrieved, 1):
    # Each retrieved item is read as a mapping with a score, the example text and its entities.
    print(f"\nExample {i} (score: {doc['retrieval_score']:.4f}):")
    print(f"Text: {doc['text'][:200]}...")
    print(f"Entities: {doc['entities']}")

## 7. Test on Sample Examples

In [None]:
# Qualitative check: show gold vs. predicted entities for the first few validation samples.
num_examples = 3
separator = '=' * 80  # 80-character rule printed above and below each example header

for example_no, sample in enumerate(val_dataset[:num_examples], start=1):
    print(f"\n{separator}")
    print(f"Example {example_no}")
    print(f"{separator}")

    text = sample['text']
    ground_truth = sample['entities']

    # Only the first 300 characters of the text are displayed.
    print(f"\nText: {text[:300]}...\n")

    # Run the RAG extractor on the full text (not the truncated preview).
    predicted = extractor.extract_entities(text)

    for heading, entities in (("Ground Truth:", ground_truth), ("\nPredicted:", predicted)):
        print(heading)
        # ensure_ascii=False keeps non-ASCII characters readable in the printed JSON.
        print(json.dumps(entities, indent=2, ensure_ascii=False))

## 8. Evaluate on Validation Set

In [None]:
# Produce predictions for the whole validation split; the extractor returns
# the predictions together with the matching ground truth.
print("Running evaluation on validation set...")
validation_run = extractor.evaluate_on_dataset(val_dataset)
predictions, ground_truth = validation_run

# Score the predictions against the ground truth for the configured entity types.
evaluator = NEREvaluator(entity_types=config.entity_types)
results = evaluator.evaluate_all(predictions, ground_truth)

# Show the metrics in the notebook output.
evaluator.print_results(results)

# Persist the same metrics to the results directory.
results_path = RESULTS_DIR / "rag_validation.json"
evaluator.save_results(results, results_path)
saved_message = f"Results saved to {results_path}"
print(saved_message)

## 9. Run Benchmark on Test Set

In [None]:
# Benchmark the RAG extractor on the held-out test split.
benchmark = NERBenchmark(config=config)
benchmark_kwargs = dict(
    method_name="RAG",
    extractor=extractor,
    test_dataset=test_dataset,
    verbose=True,
)
test_results = benchmark.run_benchmark(**benchmark_kwargs)

# Store the benchmark output under the "rag" entry of the results directory.
rag_results_target = RESULTS_DIR / "rag"
benchmark.save_results(rag_results_target)

## 10. Analysis and Insights

In [None]:
# Headline numbers taken from the test-set benchmark results.
print("\nKey Insights:")
exact_match = test_results['exact_match_accuracy']
print(f"  - Exact Match Accuracy: {exact_match:.2%}")
macro_f1 = test_results['partial_match_metrics']['macro_avg']['f1']
print(f"  - Macro F1 Score: {macro_f1:.2%}")
throughput = test_results['samples_per_second']
print(f"  - Inference Speed: {throughput:.2f} samples/second")

# The lists below are fixed text, not derived from the results above.
strengths = (
    "\nStrengths:",
    "  - Better accuracy with context from similar examples",
    "  - Reduced hallucination for rare entities",
    "  - Can adapt to domain by changing corpus",
)
weaknesses = (
    "\nWeaknesses:",
    "  - Slower than pure prompt engineering (retrieval overhead)",
    "  - Requires building and maintaining index",
    "  - Quality depends on corpus coverage",
)
for talking_points in (strengths, weaknesses):
    for line in talking_points:
        print(line)

## 11. Save Predictions for Analysis

In [None]:
# Save predictions for later inspection.
# NOTE(review): `predictions` was assigned from
# `extractor.evaluate_on_dataset(val_dataset)`, i.e. it holds validation-split
# predictions, yet it is stored next to the test benchmark output. Confirm this
# is intended, or save the benchmark's own test predictions here instead.
predictions_path = RESULTS_DIR / "rag" / "predictions.json"

# Create the target folder up front so this cell does not rely on an earlier
# cell (or on save_predictions itself) having created it.
predictions_path.parent.mkdir(parents=True, exist_ok=True)

benchmark.save_predictions(
    method_name="RAG",
    predictions=predictions,
    output_path=predictions_path
)

print("\nExperiment complete!")
print(f"Results saved to {RESULTS_DIR / 'rag'}")