[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dbamman/anlp25/blob/main/10.llms/RAG.ipynb)


# Retrieval Augmented Generation

In this notebook, we will implement a simple RAG system.

Concretely, we will begin by building a collection of document embeddings. Then, for each query, we:
1. Embed the query in that same space.
2. Use FAISS to retrieve the $n$ closest documents.
3. Incorporate those retrieved documents into the context of a prompt for an LLM.
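
As a rough sketch of these three steps, the toy example below substitutes a bag-of-words count vector for the learned embedding and a brute-force NumPy search for FAISS. The corpus and the helper names (`embed`, `retrieve`, `build_prompt`) are invented for illustration and are not part of the pipeline built later in this notebook.

```python
import numpy as np

# Hypothetical toy corpus; the notebook itself uses ACL abstracts.
docs = [
    "parsing universal dependencies from raw text",
    "neural machine translation with attention",
    "sentiment analysis of movie reviews",
]
vocab = sorted({word for doc in docs for word in doc.split()})

def embed(text):
    """Map text to a unit-length bag-of-words vector (stand-in for a learned encoder)."""
    words = text.lower().split()
    vec = np.array([words.count(w) for w in vocab], dtype=np.float32)
    norm = np.linalg.norm(vec)
    return vec / norm if norm > 0 else vec

# Document embedding collection: one row per document.
doc_matrix = np.stack([embed(d) for d in docs])

def retrieve(query, n=2):
    """Steps 1 and 2: embed the query, then return the n highest-scoring doc ids."""
    scores = doc_matrix @ embed(query)
    top = np.argsort(-scores)[:n]
    return top.tolist(), scores[top].tolist()

def build_prompt(query, doc_ids):
    """Step 3: place the retrieved documents into the prompt context."""
    context = "\n".join(f"- {docs[i]}" for i in doc_ids)
    return f"Question: {query}\n\nRelevant passages:\n{context}"

toy_query = "dependencies parsing shared task"
ids, scores = retrieve(toy_query)
print(build_prompt(toy_query, ids))
```

Only the first document shares any words with the toy query, so it is ranked first; a learned encoder would also capture similarity between texts that share no surface vocabulary.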

In [None]:
!pip install sentence-transformers

# install faiss for gpu
!pip install faiss-gpu-cu12

In [None]:
import torch
import operator

import faiss
import nltk
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

nltk.download("punkt")
nltk.download("punkt_tab")

In [None]:
# Run this early to let the model load!

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-4B", device_map="cuda", dtype="auto")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")

We'll use the ACL paper abstracts you worked with before as our set of documents.

In [None]:
!wget https://raw.githubusercontent.com/dbamman/anlp25/main/data/acl.all.tsv

In [None]:
df = pd.read_csv("acl.all.tsv", sep="\t", names=["cite", "year", "title", "abstract"])

## Building an index

We must decide on a document embedding model; even within the SentenceTransformer family, there are many pre-trained models that vary in accuracy, size, and speed. See [here](https://www.sbert.net/docs/sentence_transformer/pretrained_models.html) for a list of all models.

In particular, some models are trained for question answering tasks instead of strict semantic similarity; this means that questions will be placed in a similar region to their relevant answers.

In [None]:
# we'll normalize all vectors so that cosine similarity reduces to a dot product
# (enabling the use of inner product as a similarity metric)
def normalize(matrix):
    row_norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    normalized_rows = matrix / row_norms
    return normalized_rows

class Index():
    def __init__(self, model_name, texts):
        self.encoder = SentenceTransformer(model_name)

        doc_embeddings = self.encoder.encode(texts)
        doc_embeddings = normalize(doc_embeddings)
        num_docs, embedding_size = doc_embeddings.shape

        # Our dataset is small enough that we can use exact search, so we'll use IndexFlatIP
        # (which builds an exact index with dot product as the similarity metric)
        self.index = faiss.IndexFlatIP(embedding_size)
        self.index.add(doc_embeddings)

        # If you want to use faster but approximate search over a larger dataset, use this instead.
        # The flat index serves as the coarse quantizer, and 10 is the number of clusters (nlist).
        # quantizer = faiss.IndexFlatIP(embedding_size)
        # self.index = faiss.IndexIVFFlat(quantizer, embedding_size, 10, faiss.METRIC_INNER_PRODUCT)
        # self.index.train(doc_embeddings)
        # self.index.add(doc_embeddings)

    def query(self, query, n=3):
        query_embedding = self.encoder.encode([query])
        query_embedding = normalize(query_embedding)
        distances, indices = self.index.search(query_embedding, n)
        return distances[0], indices[0]
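
To see why normalizing lets the inner product stand in for cosine similarity, here is a small NumPy check on made-up vectors (the arrays and the `cosine` helper are illustrative only). It computes cosine similarity directly, compares it with the dot product of the normalized rows, and then ranks the documents by score, which is conceptually what an exact inner-product index returns.

```python
import numpy as np

def normalize(matrix):
    # Same row-wise L2 normalization as in the cell above.
    return matrix / np.linalg.norm(matrix, axis=1, keepdims=True)

def cosine(a, b):
    # Textbook cosine similarity between two 1-D vectors.
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

# Made-up "document" and "query" vectors, for illustration only.
toy_docs = np.array([[3.0, 4.0], [1.0, 0.0], [-2.0, 1.0]], dtype=np.float32)
toy_query = np.array([[6.0, 8.0]], dtype=np.float32)

# Inner products of unit vectors: one score per document.
ip_scores = (normalize(toy_docs) @ normalize(toy_query).T).ravel()
cos_scores = np.array([cosine(d, toy_query[0]) for d in toy_docs])

# The two ways of scoring agree up to floating-point error.
assert np.allclose(ip_scores, cos_scores, atol=1e-5)

# Rank by inner product, highest first.
ranking = np.argsort(-ip_scores)
print(ranking, ip_scores[ranking])
```

Note that this `normalize` divides by zero for an all-zero row; real sentence embeddings are very unlikely to be exactly zero, but a guard may be worth adding for other data.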


RAG depends on having a good retriever since we condition *only* on the documents (and passages of documents) that are retrieved as being relevant. Here, we experiment with two different embedding models:

1. `all-mpnet-base-v2` is a general-purpose sentence-similarity model in sentence-transformers
2. `multi-qa-mpnet-base-dot-v1` is trained on question/answer pairs

**Consider**: why might we want to train on question/answer pairs?

In [None]:
mpnet_index = Index("sentence-transformers/all-mpnet-base-v2", df.abstract)
multiqa_index = Index("sentence-transformers/multi-qa-mpnet-base-dot-v1", df.abstract)

Now let's embed our query in the same representation space and find the documents most similar to it. Which embedding index do you think provides better results?

In [None]:
query = "What was the CONLL 2018 shared task?"

distances, indices = mpnet_index.query(query)
for dist, idx in zip(distances, indices):
  print("%.3f\t%s (%d)\t%s" % (dist, df.title[idx], df.year[idx], df.abstract[idx][:150]))

In [None]:
distances, indices = multiqa_index.query(query)
for dist, idx in zip(distances, indices):
  print("%.3f\t%s (%d)\t%s" % (dist, df.title[idx], df.year[idx], df.abstract[idx][:150]))
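
One rough way to compare the two result lists is to measure how much they overlap. The helper below is a hypothetical addition (`overlap_at_n` is not part of the notebook), and the ids are made up; it computes the Jaccard overlap between two lists of retrieved document ids.

```python
def overlap_at_n(indices_a, indices_b):
    """Jaccard overlap between two lists of retrieved document ids."""
    set_a = {int(i) for i in indices_a}
    set_b = {int(i) for i in indices_b}
    if not set_a and not set_b:
        return 1.0  # two empty result lists agree trivially
    return len(set_a & set_b) / len(set_a | set_b)

# Hypothetical top-3 results from two retrievers (made-up ids).
results_a = [12, 40, 7]
results_b = [40, 3, 12]
print(overlap_at_n(results_a, results_b))
```

In the notebook, the same function could be called on the `indices` arrays returned by `mpnet_index.query(query)` and `multiqa_index.query(query)`. A low overlap only shows that the retrievers disagree; judging which one is better still requires reading the retrieved abstracts.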

## Generate

Now that we've built a retriever, let's incorporate the retrieved passages into the context of a prompt to answer our initial query.

In [None]:
from textwrap import dedent
def format_passage(data, idx):
    """
    Generates formatted paper information for a given paper index.
    """
    title = data.title.iloc[idx]
    abstract = data.abstract.iloc[idx]
    year = data.year.iloc[idx]
    cite = data.cite.iloc[idx]

    return f"""
    Title: {title}
    Year: {year}
    Cite-key: {cite}
    Abstract: {abstract}
    """

In [None]:
def prompt_model(messages, thinking=False):
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=thinking  # Qwen3's chat template enables thinking by default; here it stays off unless thinking=True
    )
    
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    
    # conduct text completion
    generated = model.generate(
        **model_inputs,
        max_new_tokens=500
    )

    # let's break this down:
    #                      | take the first element of the batch (our batch size is 1)
    #                      |  |-----------------------------| skip our original input tokens
    output_ids = generated[0][len(model_inputs.input_ids[0]):].tolist()

    # decode the token ids back into text
    return tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")


def generate_without_rag(question, thinking=False):
    system_prompt = "You're a helpful assistant for question answering."
    user_prompt = f"Question: {question}"

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    # pass the thinking flag through to the chat template
    return prompt_model(messages, thinking=thinking)

def generate_with_rag(question, index, passages=None, thinking=False, show_rag=False):
    system_prompt = dedent("""
        You're a helpful assistant for question answering. Use the information from the included passages to construct your response.
        In your response, only reference the passages with a parenthetical citation of the cite-key. Do not refer to the passages any other way.
    """).strip()

    # retrieve passages from the index unless the caller supplied them
    if passages is None:
        _, indices = index.query(question)
        passages = [format_passage(df, idx) for idx in indices]

    # dedent each passage separately so that multi-line passages don't
    # interfere with the indentation of the rest of the prompt
    passage_text = "\n\n".join(dedent(p).strip() for p in passages)
    rag_prompt = f"Question: {question}\n\nRelevant passages to the question:\n\n{passage_text}"

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": rag_prompt},
    ]
    # pass the thinking flag through to the chat template
    output = prompt_model(messages, thinking=thinking)
    if show_rag:
        output = f"{output}\n\nRAG prompt:\n{rag_prompt}"
    return output
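
The slicing step in `prompt_model` relies on the fact that, for decoder-only models such as the one used here, `generate` returns the prompt tokens followed by the newly generated tokens. The toy sketch below mimics that with plain Python lists; the token ids are made up and no model is involved.

```python
# Made-up token ids; a real tokenizer would produce these.
prompt_ids = [101, 7592, 2088]
new_ids = [2023, 2003, 102]

# Shape of what `generate` returns for a batch of one: prompt followed by continuation.
generated = [prompt_ids + new_ids]

# Same slicing as in prompt_model: take batch element 0, then skip the prompt tokens.
output_ids = generated[0][len(prompt_ids):]
print(output_ids)
```

Without the slice, decoding `generated[0]` would reproduce the full prompt (system message, question, and passages) in front of the answer.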

Now let's generate responses. First, we try generating a response without any context (relying only on the model's pretraining). Then, we try retrieval-augmented generation with the `multiqa_index` and `mpnet_index` embedding indices.

Here is a [list of CoNLL shared tasks](https://www.conll.org/previous-tasks) over the years. Are the outputs accurate? Which do you prefer?

In [None]:
query = "What was the CONLL 2018 shared task?"

In [None]:
print(generate_without_rag(query))

In [None]:
print(generate_with_rag(query, multiqa_index))

In [None]:
print(generate_with_rag(query, mpnet_index))