In [1]:
# Install the Hugging Face libraries this notebook uses. faiss-cpu is presumably
# needed by RagRetriever's passage index; sentence-transformers is not referenced
# in the cells shown here — NOTE(review): confirm it is still required.
! pip install transformers sentence-transformers faiss-cpu --quiet

In [2]:
# datasets is pinned to an exact version. NOTE(review): presumably because newer
# releases break RagRetriever's dataset loading — confirm before bumping the pin.
! pip install datasets==3.6.0 --quiet

In [None]:
import torch
from transformers import (
    RagConfig,
    RagTokenizer,
    RagRetriever,
    RagSequenceForGeneration,
    RagTokenForGeneration
)
import warnings
# Suppress every Python warning so the demo output stays readable
# (this also hides potentially useful diagnostics while experimenting).
warnings.simplefilter("ignore")

print("✓ All imports successful")
print("Using device:", "cuda" if torch.cuda.is_available() else "cpu")

# ============================================================================
# PART 2: Initialize Components with RagConfig
# ============================================================================
#
# Tokenizer, retriever and config are all loaded from the same pretrained
# checkpoint ("facebook/rag-token-nq") so that they agree with each other.

print("\n" + "=" * 60, "Loading RAG Components...", "=" * 60, sep="\n")

# Tokenizer that ships with the RAG checkpoint.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
print("✓ Tokenizer loaded")

# Retriever over the small *dummy* dataset with an exact index: quick to fetch
# and good enough for smoke tests. Switch use_dummy_dataset to False to
# retrieve from the real passage collection.
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq",
    use_dummy_dataset=True,
    index_name="exact",
)
print("✓ Retriever loaded")

# Configuration of the checkpoint; n_docs is the number of passages
# retrieved per query.
config = RagConfig.from_pretrained("facebook/rag-token-nq")
print("✓ Config loaded - Retrieved docs per query:", config.n_docs)
print("\n" + "=" * 60, "RagTokenForGeneration", "=" * 60, sep="\n")
print("Token-level generation: generates answers token by token", end="\n\n")

# Load the RAG-Token model from the same checkpoint and attach the retriever
# explicitly, so generate() can fetch supporting passages by itself from the
# tokenized question alone.
model_token = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
print("✓ RagTokenForGeneration model loaded")


questions_token = [
    "What is the capital of France?",
    "Who invented the telephone?",
    "When was Python programming language created?",
]

# from_pretrained() loads the weights onto the CPU by default. Move the model
# explicitly (once, before the loop) so generation actually runs on a GPU when
# one is available; each tokenized question is moved to the same device below,
# because model and inputs must live on the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_token.to(device)

print("\nGenerating answers with RagTokenForGeneration...")
print("-" * 60)
for question in questions_token:
    # Tokenize the query and move every tensor in the BatchEncoding
    # (input_ids, attention_mask) onto the model's device.
    input_dict = tokenizer(
        question,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)

    # Only the question tensors are passed: the retriever attached to the
    # model supplies the context passages. Beam search with 2 beams, capped
    # at 50 tokens.
    with torch.no_grad():
        generated = model_token.generate(
            input_ids=input_dict["input_ids"],
            attention_mask=input_dict.get("attention_mask"),
            num_beams=2,
            max_length=50,
            early_stopping=True,
        )
    # One question per call, so the first decoded sequence is the answer.
    answer = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

    print(f"Q: {question}")
    print(f"A: {answer}")
    print()