# Ben Needs a Friend - Retrieval Augmented Generation (RAG)
This is part of the "Ben Needs a Friend" tutorial. See all the notebooks and materials [here](https://github.com/bpben/ben_friend_25). Follow setup instructions there to use this notebook.

In this notebook, we set up an approach to use a set of documents ("memories") in a Retrieval Augmented Generation (RAG) workflow.

This notebook is intended to be run in Kaggle Notebooks with GPU acceleration.  Access that version [here](https://www.kaggle.com/code/bpoben/ben-needs-a-friend-rag). 

If you want to run this locally, edit the `model_name` path. Note that this assumes GPU access; without a GPU, it may be slow or not work at all.

In [None]:
from llamabot import SimpleBot, StructuredBot, ChatBot
import json
from pydantic import BaseModel
import tempfile

sft_model = "qwen2.5:1.5b"

### Vector stores
The first part of RAG is "retrieval". To do that, we need a mechanism for the model to retrieve relevant information. One approach is to create an "embedding" for each memory I have with my AI friend, which can then be compared against an embedding of the input prompt.
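
To make "compared against the input prompt" concrete, here is a minimal sketch using a toy bag-of-words embedding and cosine similarity. The `embed` and `cosine_similarity` helpers, the tiny vocabulary, and the sample sentences are all illustrative assumptions, not what LanceDB or Llamabot do internally; real embedding models produce dense, learned vectors.

```python
import math

# Toy vocabulary (an assumption for illustration): one vector slot per word
VOCAB = ["video", "games", "skiing", "snow", "ben", "friend"]


def embed(text):
    """Count how often each vocabulary word appears in the text."""
    for punctuation in ".,?!":
        text = text.replace(punctuation, "")
    words = text.lower().split()
    return [float(words.count(word)) for word in VOCAB]


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


toy_memories = [
    "Ben and Friend played video games.",
    "Friend went skiing in the snow.",
]
toy_query = "Remember the video games?"

# Score every memory against the query and keep the most similar one
toy_scores = [cosine_similarity(embed(toy_query), embed(m)) for m in toy_memories]
best_memory = toy_memories[toy_scores.index(max(toy_scores))]
print(best_memory)
```

The query shares "video" and "games" with the first memory and nothing with the second, so the first memory scores highest. A vector store does the same kind of comparison, just with better embeddings and faster search.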

#### LanceDB implementation
One approach to setting up this vector store is to use [LanceDB's implementation of embedding](https://lancedb.github.io/lancedb/embeddings/embedding_functions/).  Llamabot uses this by default.  Below is an overview of what happens under the hood, but we'll just rely on Llamabot's implementation.

In [None]:
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry

# create a database
db = lancedb.connect("/tmp/db")
db.drop_all_tables()
# initialize a default sentence-transformers model (all-MiniLM-L6-v2)
model = get_registry().get("sentence-transformers").create()

# specify a schema (just text + vector)
class Words(LanceModel):
    text: str = model.SourceField()
    vector: Vector(model.ndims()) = model.VectorField()


try:
    table = db.create_table("words", schema=Words)
except ValueError:
    table = db.open_table("words")

# add in some entries
table.add(
    [
        {"text": "hello world"},
        {"text": "goodbye world"}
    ]
)


In [None]:
# look at the entries
table.head()

In [None]:
query = "greetings"
search_query = table.search(query)
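# peek at the first 10 values of the embedded query
# (_query is a private attribute and may change between LanceDB versions)
search_query._query[:10]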

In [None]:
# get a single (most similar) result, translate it into the pydantic model
search_query.limit(1).to_pydantic(Words)[0].text

In [None]:
query = "farewell"
result = table.search(query).limit(1).to_pydantic(Words)[0]
print(result.text)

In [None]:
table.search(query).limit(1).to_pydantic(Words)

In [None]:
query = "random word"
result = table.search(query).limit(1).to_pydantic(Words)[0]
print(result.text)
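
Note that a nearest-neighbor search returns the closest entry even when nothing in the table is really related to the query, as with "random word" above. One common mitigation is to discard matches whose distance exceeds a cutoff. The sketch below assumes results arrive as `(text, distance)` pairs; the `filter_by_distance` helper, the sample distances, and the cutoff values are hypothetical and would need tuning for a real embedding model. LanceDB search results should include a `_distance` column you can inspect (for example via `to_pandas()`), but treat that as an assumption to verify against your installed version.

```python
def filter_by_distance(results, max_distance):
    """Keep only results whose distance is at or below the cutoff."""
    return [(text, dist) for text, dist in results if dist <= max_distance]


# Made-up distances for illustration; smaller means more similar
sample_results = [("hello world", 0.45), ("goodbye world", 1.30)]

# A loose cutoff keeps the close match and drops the distant one
print(filter_by_distance(sample_results, max_distance=1.0))

# A strict cutoff can leave nothing, signalling "no relevant memory found"
print(filter_by_distance(sample_results, max_distance=0.1))
```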

Llamabot provides a class called `QueryBot` which implements everything above for you and allows you to just query the vector database.

So let's first write some "memories" as documents:

In [None]:
memories = [
    'Ben is really bad at video games, but Friend is excellent.',
    'Friend is a pro skier, but Ben is terrified.',
]

memory_filenames = []

for i, m in enumerate(memories):
    # write each memory to its own temporary file;
    # delete=False keeps the file on disk after it is closed
    with tempfile.NamedTemporaryFile(
            delete=False, mode='w', suffix=f'_memory_{i}.txt') as f:
        f.write(m)
    print(f"Memory {i} written to {f.name}")
    memory_filenames.append(f.name)

In [None]:
from llamabot import QueryBot
from pathlib import Path

friend_prompt = """Your name is Friend.  \
You are having a conversation with your close friend Ben. \
You and Ben are sarcastic and poke fun at one another. \
But you care about each other and support one another."""

query_completer = QueryBot(
    system_prompt=friend_prompt,
    model_name=f"ollama_chat/{sft_model}",
    collection_name="memories",
    document_paths=memory_filenames
)

# note - to replace existing memories, reset the collection first:
# query_completer.docstore.reset()


In [None]:
query = "Remember that time we played video games?"
print("Retrieved memory: ", 
      query_completer.docstore.retrieve(query, 1))

response = query_completer(query,
                n_results=1)

In [None]:
query = "Remember when we went skiing?"
print("Retrieved memory: ", 
      query_completer.docstore.retrieve(query, 1))

response = query_completer(query,
                n_results=1)
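
Conceptually, the "augmented generation" half of RAG amounts to pasting the retrieved memories into the prompt before the model sees it. The sketch below is an assumption about the general pattern, not Llamabot's actual prompt template; `build_rag_prompt` is a hypothetical helper and the prompt layout is made up for illustration.

```python
def build_rag_prompt(system_prompt, retrieved, query):
    """Combine the system prompt, retrieved memories, and user query into one prompt."""
    context = "\n".join(f"- {memory}" for memory in retrieved)
    return (
        f"{system_prompt}\n\n"
        f"Relevant memories:\n{context}\n\n"
        f"Ben: {query}"
    )


example_prompt = build_rag_prompt(
    system_prompt="Your name is Friend.",
    retrieved=["Ben is really bad at video games, but Friend is excellent."],
    query="Remember that time we played video games?",
)
print(example_prompt)
```

Because the memory is placed directly in the prompt, the model can refer to it without any retraining; changing what Friend "remembers" only requires changing the documents in the vector store.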