<a href="https://colab.research.google.com/github/datastax/ragstack-ai/blob/main/examples/notebooks/llama-astra.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# RAG with LlamaIndex and Astra DB

Build a RAG pipeline with RAGStack, Astra DB, and LlamaIndex.

## Prerequisites
You will need a vector-enabled Astra database.

1. Create an [Astra vector database](https://docs.datastax.com/en/astra-serverless/docs/getting-started/create-db-choices.html).
2. Within your database, create an [Astra DB Access Token](https://docs.datastax.com/en/astra-serverless/docs/manage/org/manage-tokens.html) with Database Administrator permissions.
3. Copy your Astra DB API endpoint, which has the form `https://<ASTRA_DB_ID>-<ASTRA_DB_REGION>.apps.astra.datastax.com`.
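If you want to sanity-check the endpoint before pasting it into the setup cell, the URL pattern above fits in a few lines of Python. This is a minimal sketch: `astra_endpoint` and `looks_like_astra_endpoint` are hypothetical helpers, not part of RAGStack or the Astra DB client, and the check assumes the database ID is a lowercase UUID.

```python
import re

# Hypothetical helper: assemble an endpoint from a database ID and region,
# following the https://<ASTRA_DB_ID>-<ASTRA_DB_REGION>.apps.astra.datastax.com pattern.
def astra_endpoint(db_id: str, region: str) -> str:
    return f"https://{db_id}-{region}.apps.astra.datastax.com"

# Assumption: the database ID is a lowercase UUID, followed by "-<region>".
_ENDPOINT_RE = re.compile(
    r"^https://[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
    r"-[a-z0-9-]+\.apps\.astra\.datastax\.com/?$"
)

# Hypothetical helper: rough format check for a pasted endpoint.
def looks_like_astra_endpoint(url: str) -> bool:
    return bool(_ENDPOINT_RE.match(url.strip()))

example = astra_endpoint("01234567-89ab-cdef-0123-456789abcdef", "us-east1")
print(example, looks_like_astra_endpoint(example))
```

The check only validates the shape of the URL; it cannot tell whether the database exists or the region is real.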

See the [Prerequisites](https://docs.datastax.com/en/ragstack/docs/prerequisites.html) page for more details.

## Setup

```python
!pip install ragstack-ai
```

```python
import os
from getpass import getpass

# Enter your settings for Astra DB and OpenAI:
os.environ["ASTRA_DB_API_ENDPOINT"] = input("Enter your Astra DB API Endpoint: ")
os.environ["ASTRA_DB_APPLICATION_TOKEN"] = getpass("Enter your Astra DB Token: ")
os.environ["OPENAI_API_KEY"] = getpass("Enter your OpenAI API Key: ")
```

## Create RAG pipeline

### Embedding model and vector store

Download a sample dataset (Paul Graham's essay) from Llama Hub into a local `data` directory. You will load it into your Astra vector store in a later step.

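```python
import os

from llama_index.core.llama_dataset import download_llama_dataset

# Create the download directory (same effect as `mkdir -p data`).
os.makedirs("data", exist_ok=True)

# Download the dataset; its source files are read from ./data/source_files below.
dataset = download_llama_dataset(
  "PaulGrahamEssayDataset", "./data"
)
```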

Load the downloaded source files into memory as LlamaIndex documents, then inspect the first one.

```python
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, StorageContext

# Read every file under ./data/source_files into a list of Document objects.
documents = SimpleDirectoryReader("./data/source_files").load_data()

# Print a short summary of the first document.
print(f"Total documents: {len(documents)}")
print(f"First document, id: {documents[0].doc_id}")
print(f"First document, hash: {documents[0].hash}")
print(
    "First document, text"
    f" ({len(documents[0].text)} characters):\n"
    f"{'=' * 20}\n"
    f"{documents[0].text[:360]} ..."
)
```

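Create an Astra DB vector store and build a vector index over the documents. Building the index splits the documents into nodes (text chunks), embeds each node with LlamaIndex's default OpenAI embedding model, and stores the vectors in the `test_llama` collection. `embedding_dimension` must match the output size of that embedding model (1536 for the default OpenAI model).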

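```python
import os

from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.vector_stores.astra_db import AstraDBVectorStore

# Connect to the "test_llama" collection in your Astra database.
astra_db_store = AstraDBVectorStore(
    token=os.getenv("ASTRA_DB_APPLICATION_TOKEN"),
    api_endpoint=os.getenv("ASTRA_DB_API_ENDPOINT"),
    collection_name="test_llama",
    embedding_dimension=1536,  # must match the embedding model's vector size
)

# Use the Astra DB store as the index's vector storage.
storage_context = StorageContext.from_defaults(vector_store=astra_db_store)

# Split, embed, and write the documents to Astra DB.
index = VectorStoreIndex.from_documents(
    documents, storage_context=storage_context
)
```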
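`VectorStoreIndex.from_documents` does not embed each essay as a single vector; it first splits documents into smaller nodes. The sketch below shows the general idea with fixed-size, overlapping character windows. It is an illustration only: LlamaIndex's default splitter is sentence-aware and measures chunk size in tokens, and `chunk_text` is a hypothetical helper.

```python
def chunk_text(text: str, chunk_size: int = 200, overlap: int = 50) -> list[str]:
    """Split text into character windows of chunk_size that overlap by `overlap` characters."""
    if chunk_size <= 0 or not 0 <= overlap < chunk_size:
        raise ValueError("need chunk_size > 0 and 0 <= overlap < chunk_size")
    step = chunk_size - overlap  # how far each window advances
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):  # this window reached the end of the text
            break
    return chunks

print(chunk_text("abcdefghij", chunk_size=4, overlap=1))
```

The overlap keeps a sentence that straddles a boundary partly visible in both neighboring chunks, which tends to help retrieval.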

### Query

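Create a query engine and ask, "Why did the author choose to work on AI?" The query engine retrieves the most relevant nodes from the vector store and passes them to the LLM (OpenAI by default), which generates an answer from that context.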

```python
# Retrieve relevant nodes and have the LLM answer using them as context.
query_engine = index.as_query_engine()
query_string_1 = "Why did the author choose to work on AI?"
response = query_engine.query(query_string_1)

print(query_string_1)
print(response.response)
```
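The "augmented generation" step amounts to placing the retrieved text in the LLM prompt alongside the question. The sketch below shows that prompt-assembly step in plain Python. `QA_TEMPLATE` and `build_prompt` are hypothetical; LlamaIndex's actual prompt template and response-synthesis strategy may differ.

```python
# Hypothetical question-answering template with slots for context and question.
QA_TEMPLATE = (
    "Context information is below.\n"
    "---------------------\n"
    "{context}\n"
    "---------------------\n"
    "Using only the context above, answer the question.\n"
    "Question: {question}\n"
    "Answer: "
)

def build_prompt(question: str, chunks: list[str]) -> str:
    """Join retrieved chunks into one context block and fill in the template."""
    context = "\n\n".join(chunks)
    return QA_TEMPLATE.format(context=context, question=question)

print(build_prompt("Why did the author choose to work on AI?", ["chunk one", "chunk two"]))
```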

To see which nodes the query engine works with, use a retriever directly. With `similarity_top_k=3`, it returns the three nodes most similar to your prompt, each with its similarity score. A retriever does not call the LLM.

```python
# Default mode: return the 3 nodes most similar to the prompt.
retriever = index.as_retriever(
    vector_store_query_mode="default",
    similarity_top_k=3,
)

nodes_with_scores = retriever.retrieve(query_string_1)

print(query_string_1)
print(f"Found {len(nodes_with_scores)} nodes.")
for idx, node_with_score in enumerate(nodes_with_scores):
    print(f"    [{idx}] score = {node_with_score.score}")
    print(f"        id    = {node_with_score.node.node_id}")
    print(f"        text  = {node_with_score.node.text[:90]} ...")
```

### MMR

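Switch the retriever to Maximal Marginal Relevance (MMR) mode. Instead of returning the three nodes most similar to the prompt, MMR picks nodes that are relevant to the prompt but not redundant with each other. With `mmr_prefetch_factor=4`, the vector store first fetches `similarity_top_k × 4` (here, 12) candidates by similarity, then selects three of them with MMR.

Send the prompt again. In this mode the scores are MMR scores rather than plain similarities: the first node is the best match for the prompt, and each later node's score is its relevance minus a penalty for overlapping with nodes already selected. Later scores can therefore be small or negative; a negative score means the node was chosen for diversity, not that it is irrelevant.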

```python
# MMR mode: prefetch 3 * 4 = 12 candidates, then pick 3 balancing relevance and diversity.
retriever = index.as_retriever(
    vector_store_query_mode="mmr",
    similarity_top_k=3,
    vector_store_kwargs={"mmr_prefetch_factor": 4},
)

nodes_with_scores = retriever.retrieve(query_string_1)

print(query_string_1)
print(f"Found {len(nodes_with_scores)} nodes.")
for idx, node_with_score in enumerate(nodes_with_scores):
    print(f"    [{idx}] score = {node_with_score.score}")
    print(f"        id    = {node_with_score.node.node_id}")
    print(f"        text  = {node_with_score.node.text[:90]} ...")
```

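## Cleanup

To delete the `test_llama` collection when you no longer need it, uncomment the last line of the following cell and run it.

```python
# WARNING: This will delete the collection and all documents in the collection
# astra_db_store.delete_collection()
```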