# RAG backed by SQL and Jina Reranker

In this example, we will build a simple RAG system that draws its information from an SQL database instead of a document store.

Steps:
- Given an SQL database, we extract the SQL table definitions (the `CREATE TABLE` statements in an SQL dump) and store them. Here, the definitions are stored in memory as a list; scaling up may require more sophisticated storage.
- Users enter a query in natural language.
- [`jinaai/jina-reranker-v2-base-multilingual`](https://huggingface.co/jinaai/jina-reranker-v2-base-multilingual), an SQL-aware reranking model from [Jina AI](https://jina.ai/), sorts the table definitions in order of their relevance to the user's query.
- We present [`mistralai/Mistral-7B-Instruct-v0.1`](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) with a prompt containing the user's query and the top three table definitions, with a request to write an SQL query to fit the task.
- Mistral Instruct generates an SQL query and we run it against the database, retrieving a result.
- The SQL query result is converted to JSON and presented to Mistral Instruct in a new prompt, along with the user's original query, the SQL query, and a request to compose an answer for the user in natural language.
- Mistral Instruct's natural language text response is returned to the user.
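The steps above can be sketched as a single function. This is a minimal sketch, not the implementation used later in this notebook: the reranker and the LLM are passed in as plain Python callables (`ranker` and `generate` are hypothetical parameters), and the prompts are shortened stand-ins for the templates defined below.

```python
import json
import sqlite3
from typing import Callable, List, Tuple

# Hypothetical signatures for the two model calls. In this notebook they are
# played by rank_tables (Jina Reranker v2) and mistral_llm.complete.
Ranker = Callable[[str, List[str], int], List[Tuple[float, str]]]
Generator = Callable[[str], str]


def sql_rag_pipeline(question: str, table_defs: List[str], ranker: Ranker,
                     generate: Generator, con: sqlite3.Connection,
                     top_n: int = 3) -> str:
    # Rank the stored table definitions against the question, keep the top N.
    top_tables = [spec for _, spec in ranker(question, table_defs, top_n)]
    # Ask the LLM for an SQL query that uses only those tables.
    sql_prompt = (f"Question: {question}\nTables:\n" + "\n\n".join(top_tables)
                  + "\nOutput only an SQL query.")
    sql_query = generate(sql_prompt)
    # Run the generated query against the database.
    rows = con.execute(sql_query).fetchall()
    # Give the JSON-encoded result back to the LLM for a natural language answer.
    answer_prompt = (f"Question: {question}\nSQL: {sql_query}\n"
                     f"Result (JSON): {json.dumps(rows)}\n"
                     "Answer in natural language.")
    return generate(answer_prompt)
```

Passing the models in as callables keeps the control flow testable without a GPU or an API token.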

**Database**:

In this example, we will use a small open-access database of [video game sales records](https://github.com/bbrumm/databasestar/tree/main/sample_databases/sample_db_videogames/sqlite) with the help of SQLite.

**Requirements**:

We will run the Jina Reranker v2 model locally with GPU support. We will also use the LlamaIndex RAG framework and the HuggingFace Inference API to access Mistral 7B Instruct v0.1.


## Setup

In [None]:
!pip install -qU transformers einops llama-index llama-index-postprocessor-jinaai-rerank  llama-index-llms-huggingface "huggingface_hub[inference]"

### Download the database

In [None]:
!wget https://github.com/bbrumm/databasestar/raw/main/sample_databases/sample_db_videogames/sqlite/videogames.db

### Download and run Jina Reranker v2

In [None]:
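from transformers import AutoModelForSequenceClassification

# trust_remote_code=True loads Jina's custom model code, which provides compute_score().
reranker_model = AutoModelForSequenceClassification.from_pretrained(
    'jinaai/jina-reranker-v2-base-multilingual',
    torch_dtype='auto',  # use the checkpoint's native precision
    trust_remote_code=True
).to('cuda')  # move the model to the GPU
reranker_model.eval()  # inference mode: disables dropout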

### Set up Mistral Instruct

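We will use LlamaIndex to create a client object for the Hugging Face Inference API. The API token is read from Colab's secrets store under the name `HF_TOKEN`; outside Colab, you will need to supply it some other way.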

In [None]:
from google.colab import userdata

hf_token = userdata.get('HF_TOKEN')

In [None]:
from llama_index.llms.huggingface import HuggingFaceInferenceAPI

mistral_llm = HuggingFaceInferenceAPI(
    model_name='mistralai/Mistral-7B-Instruct-v0.1',
    token=hf_token
)

## Using SQL-Aware Jina Reranker v2

We extracted the eight table definitions from the database import files located on GitHub.

In [None]:
table_declarations = [
    "CREATE TABLE platform (\n\tid INTEGER PRIMARY KEY,\n\tplatform_name TEXT DEFAULT NULL\n);",
    "CREATE TABLE genre (\n\tid INTEGER PRIMARY KEY,\n\tgenre_name TEXT DEFAULT NULL\n);",
    "CREATE TABLE publisher (\n\tid INTEGER PRIMARY KEY,\n\tpublisher_name TEXT DEFAULT NULL\n);",
    "CREATE TABLE region (\n\tid INTEGER PRIMARY KEY,\n\tregion_name TEXT DEFAULT NULL\n);",
    "CREATE TABLE game (\n\tid INTEGER PRIMARY KEY,\n\tgenre_id INTEGER,\n\tgame_name TEXT DEFAULT NULL,\n\tCONSTRAINT fk_gm_gen FOREIGN KEY (genre_id) REFERENCES genre(id)\n);",
    "CREATE TABLE game_publisher (\n\tid INTEGER PRIMARY KEY,\n\tgame_id INTEGER DEFAULT NULL,\n\tpublisher_id INTEGER DEFAULT NULL,\n\tCONSTRAINT fk_gpu_gam FOREIGN KEY (game_id) REFERENCES game(id),\n\tCONSTRAINT fk_gpu_pub FOREIGN KEY (publisher_id) REFERENCES publisher(id)\n);",
    "CREATE TABLE game_platform (\n\tid INTEGER PRIMARY KEY,\n\tgame_publisher_id INTEGER DEFAULT NULL,\n\tplatform_id INTEGER DEFAULT NULL,\n\trelease_year INTEGER DEFAULT NULL,\n\tCONSTRAINT fk_gpl_gp FOREIGN KEY (game_publisher_id) REFERENCES game_publisher(id),\n\tCONSTRAINT fk_gpl_pla FOREIGN KEY (platform_id) REFERENCES platform(id)\n);",
    "CREATE TABLE region_sales (\n\tregion_id INTEGER DEFAULT NULL,\n\tgame_platform_id INTEGER DEFAULT NULL,\n\tnum_sales REAL,\n   CONSTRAINT fk_rs_gp FOREIGN KEY (game_platform_id) REFERENCES game_platform(id),\n\tCONSTRAINT fk_rs_reg FOREIGN KEY (region_id) REFERENCES region(id)\n);",
]
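Rather than copying the definitions by hand, they could also be read from the database file itself. A minimal sketch, assuming a hypothetical helper `load_table_declarations`: SQLite keeps the text of each `CREATE` statement in the `sql` column of its `sqlite_master` catalog. The statements returned this way may differ in whitespace and order from the hand-written list above.

```python
import sqlite3
from typing import List


def load_table_declarations(db_path: str) -> List[str]:
    """Return the CREATE TABLE statements stored in an SQLite database file.

    SQLite stores the original DDL text of each table in the `sql` column
    of `sqlite_master`; internal tables named `sqlite_%` are skipped.
    """
    con = sqlite3.connect(db_path)
    try:
        rows = con.execute(
            "SELECT sql FROM sqlite_master "
            "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' "
            "ORDER BY name"
        ).fetchall()
    finally:
        con.close()
    # The stored text has no trailing ';', so add one to match the list above.
    return [row[0] + ";" for row in rows]
```

For example, `load_table_declarations('videogames.db')` would return one string per table.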

Next, we define a function that takes a natural language query and the list of table definitions, scores each definition against the query with Jina Reranker v2, and returns them in order from highest to lowest score.

In [None]:
from typing import List, Tuple

def rank_tables(query: str, table_specs: List[str], top_n: int = 0) -> List[Tuple[float, str]]:
    """Score each table specification against the query with the reranker.

    Returns (score, table_spec) pairs sorted from highest to lowest score:
    only the top_n pairs if top_n > 0, otherwise (the default) all of them.
    """
    pairs = [[query, table_spec] for table_spec in table_specs]
    scores = reranker_model.compute_score(pairs)
    scored_tables = [
        (score, table_spec) for score, table_spec in zip(scores, table_specs)
    ]
    scored_tables.sort(key=lambda x: x[0], reverse=True)

    if top_n and top_n < len(scored_tables):
        return scored_tables[0:top_n]
    return scored_tables
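To check the sort-and-truncate logic without loading the model, the scorer can be swapped out. The sketch below uses a hypothetical `rank_with` function that accepts any scoring callable, plus a toy keyword-overlap scorer; keyword overlap is only a stand-in for testing and is not how Jina Reranker v2 computes relevance.

```python
import re
from typing import Callable, List, Tuple


def rank_with(score_fn: Callable[[str, str], float], query: str,
              table_specs: List[str], top_n: int = 0) -> List[Tuple[float, str]]:
    # Same logic as rank_tables: score every spec, sort descending, truncate.
    scored = sorted(((score_fn(query, spec), spec) for spec in table_specs),
                    key=lambda pair: pair[0], reverse=True)
    return scored[:top_n] if top_n > 0 else scored


def keyword_overlap(query: str, spec: str) -> float:
    # Toy scorer: fraction of the query's words that also appear in the spec.
    query_words = set(re.findall(r"[a-z]+", query.lower()))
    spec_words = set(re.findall(r"[a-z]+", spec.lower()))
    return len(query_words & spec_words) / max(len(query_words), 1)
```

Keeping the scorer pluggable means the ranking code can be tested independently of the GPU-backed model.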

Jina Reranker v2 scores every table definition we give it. By default (`top_n=0`), this function returns all of them with their scores; below, we pass `top_n=3` to keep only the three most relevant.

In [None]:
user_query = "Identify the top 10 platforms by total sales"

In [None]:
ranked_tables = rank_tables(user_query, table_declarations, top_n=3)
ranked_tables

The output contains the tables `region_sales`, `platform`, and `game_platform`.
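To list which tables were selected without printing the full definitions, a small helper can pull the table name out of each `CREATE TABLE` statement. `table_name` and `summarize_ranking` are hypothetical helpers, and the regular expression assumes definitions shaped like the ones above.

```python
import re
from typing import List, Optional, Tuple

_TABLE_NAME = re.compile(r"CREATE\s+TABLE\s+(\w+)", re.IGNORECASE)


def table_name(spec: str) -> Optional[str]:
    # Extract the table name from a CREATE TABLE statement, if there is one.
    match = _TABLE_NAME.search(spec)
    return match.group(1) if match else None


def summarize_ranking(ranked: List[Tuple[float, str]]) -> List[Tuple[Optional[str], float]]:
    # Turn (score, definition) pairs into compact (table name, rounded score) pairs.
    return [(table_name(spec), round(float(score), 3)) for score, spec in ranked]
```

For example, `summarize_ranking(ranked_tables)` returns `(name, score)` pairs in ranked order.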

## Using Mistral Instruct to generate SQL

We will have Mistral Instruct v0.1 write an SQL query that fulfils the user's query, based on the declarations of the top three tables according to the reranker.

First, we need to make a prompt for that purpose using LlamaIndex's `PromptTemplate` class.

In [None]:
from llama_index.core import PromptTemplate

make_sql_prompt_tmpl_text = '''
Generate a SQL query to answer the following question from the user:
"{query_str}"

The SQL query should use only tables with the following SQL definitions:

Table 1:
{table_1}

Table 2:
{table_2}

Table 3:
{table_3}

Make sure you ONLY output an SQL query and no explanation.
'''

make_sql_prompt_tmpl = PromptTemplate(make_sql_prompt_tmpl_text)

In [None]:
make_sql_prompt = make_sql_prompt_tmpl.format(
    query_str=user_query,
    table_1=ranked_tables[0][1],
    table_2=ranked_tables[1][1],
    table_3=ranked_tables[2][1]
)

print(make_sql_prompt)
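The template above hard-codes exactly three tables, so formatting it fails if the reranker returns fewer. A hypothetical `build_sql_prompt` that numbers however many definitions it receives might look like this; it uses plain Python string formatting rather than LlamaIndex's `PromptTemplate`, but keeps the same layout.

```python
from typing import List, Tuple


def build_sql_prompt(query: str, ranked: List[Tuple[float, str]]) -> str:
    # Number each table definition, however many the reranker returned.
    table_blocks = "\n\n".join(
        f"Table {i}:\n{spec}" for i, (_, spec) in enumerate(ranked, start=1)
    )
    return (
        "Generate a SQL query to answer the following question from the user:\n"
        f'"{query}"\n\n'
        "The SQL query should use only tables with the following SQL definitions:\n\n"
        f"{table_blocks}\n\n"
        "Make sure you ONLY output an SQL query and no explanation.\n"
    )
```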

In [None]:
response = mistral_llm.complete(make_sql_prompt)
sql_query = str(response)
print(sql_query)
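The raw model output is not always bare SQL: the combined function at the end of this notebook strips backslashes for this reason, and instruction-tuned models may also wrap code in Markdown fences (an assumption about typical model behavior, not something observed here). A hypothetical `clean_sql` helper that handles both cases:

```python
import re

# Matches a Markdown code fence (three backticks, optional "sql" tag) and captures its body.
_FENCE = re.compile(r"`{3}(?:sql)?\s*(.*?)`{3}", re.DOTALL | re.IGNORECASE)


def clean_sql(raw: str) -> str:
    # If the model wrapped its answer in a code fence, keep only the inside.
    match = _FENCE.search(raw)
    sql = match.group(1) if match else raw
    # Drop stray backslashes and surrounding whitespace.
    return sql.replace("\\", "").strip()
```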

## Running the SQL query

We will use Python's built-in `sqlite3` module to run the generated query against the database `videogames.db`.

In [None]:
import sqlite3

con = sqlite3.connect('videogames.db')
cur = con.cursor()
sql_response = cur.execute(sql_query).fetchall()

In [None]:
sql_response
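Because the query is written by a model, it is worth making sure it cannot modify the database. One option, sketched here with a hypothetical `run_readonly` helper, is to open the file through an SQLite URI with `mode=ro`, so any write statement fails with an error. This assumes a plain file path without characters such as `?` or `#`, which would need URI escaping.

```python
import sqlite3
from typing import Any, List, Tuple


def run_readonly(db_path: str, sql_query: str) -> List[Tuple[Any, ...]]:
    # mode=ro opens the database read-only, so INSERT, UPDATE, DROP, etc. fail.
    con = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
    try:
        return con.execute(sql_query).fetchall()
    finally:
        con.close()
```

For this notebook, `run_readonly('videogames.db', sql_query)` would replace the two-step connect-and-execute above.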

## Getting a natural language answer

Now we will pass the user's query, the SQL query, and the result back to Mistral Instruct with a new prompt template.

In [None]:
rag_prompt_tmpl_str = """
Use the information in the JSON table to answer the following user query.
Do not explain anything, just answer concisely. Use natural language in your
answer, not computer formatting.

USER QUERY: {query_str}

JSON table:
{json_table}

This table was generated by the following SQL query:
{sql_query}

Answer ONLY using the information in the table and the SQL query, and if the
table does not provide the information to answer the question, answer
"No Information".
"""

rag_prompt_tmpl = PromptTemplate(rag_prompt_tmpl_str)

We will convert the SQL result into JSON so that it can be inserted into the prompt as text Mistral Instruct can read.

In [None]:
import json

user_query = "Identify the top 10 platforms by total sales"

rag_prompt = rag_prompt_tmpl.format(
    query_str=user_query,
    json_table=json.dumps(sql_response),
    sql_query=sql_query
)
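`json.dumps(sql_response)` produces bare arrays of values with no column names, so the model has to infer what each value means from the SQL query. A hypothetical `rows_to_json` variant labels each value using the cursor's `description` attribute, whose entries are 7-tuples starting with the column name.

```python
import json
import sqlite3


def rows_to_json(cur: sqlite3.Cursor) -> str:
    # description is None for statements that return no rows (e.g. UPDATE).
    if cur.description is None:
        return "[]"
    columns = [col[0] for col in cur.description]
    # One JSON object per row, keyed by column name.
    records = [dict(zip(columns, row)) for row in cur.fetchall()]
    return json.dumps(records)
```

It must be given the cursor before its rows are fetched, e.g. `rows_to_json(con.execute(sql_query))`.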

In [None]:
print(rag_prompt)

In [None]:
rag_response = mistral_llm.complete(rag_prompt)

print(str(rag_response))

## Combining into one function

In [None]:
def answer_sql(user_query: str) -> str:
    # 1. Rank the table definitions and keep the three most relevant.
    try:
        ranked_tables = rank_tables(user_query, table_declarations, top_n=3)
    except Exception:
        print(f"Ranking failed.\nUser query:\n{user_query}\n\n")
        raise

    # 2. Ask Mistral Instruct for an SQL query over those tables.
    make_sql_prompt = make_sql_prompt_tmpl.format(
        query_str=user_query,
        table_1=ranked_tables[0][1],
        table_2=ranked_tables[1][1],
        table_3=ranked_tables[2][1]
    )

    try:
        response = mistral_llm.complete(make_sql_prompt)
    except Exception:
        print(f"SQL query generation failed.\nPrompt:\n{make_sql_prompt}\n\n")
        raise

    # Backslash removal is a necessary hack because Mistral
    # sometimes puts backslashes in its generated code.
    sql_query = str(response).replace('\\', '')

    # 3. Run the generated query, always closing the connection afterwards.
    con = sqlite3.connect('videogames.db')
    try:
        sql_response = con.cursor().execute(sql_query).fetchall()
    except Exception:
        print(f"SQL querying failed. Query:\n{sql_query}\n\n")
        raise
    finally:
        con.close()

    # 4. Ask Mistral Instruct to phrase the result as a natural language answer.
    rag_prompt = rag_prompt_tmpl.format(
        query_str=user_query,
        json_table=json.dumps(sql_response),
        sql_query=sql_query
    )
    try:
        rag_response = mistral_llm.complete(rag_prompt)
    except Exception:
        print(f"Answer generation failed. Prompt:\n{rag_prompt}\n\n")
        raise
    return str(rag_response)

**Testing**...

In [None]:
print(answer_sql("Identify the top 10 platforms by total sales."))

In [None]:
print(answer_sql("Summarize sales by region."))

In [None]:
print(answer_sql("List the publisher with the largest number of published games."))

In [None]:
print(answer_sql("Display the year with most games released."))

In [None]:
print(answer_sql("What is the most popular game genre on the Wii platform?"))

In [None]:
print(answer_sql("What is the most popular game genre of 2012?"))
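Since `answer_sql` re-raises after printing its diagnostics, a run over many questions stops at the first failure. A hypothetical `run_questions` harness that records each error and moves on to the next question:

```python
from typing import Callable, Dict, List


def run_questions(answer_fn: Callable[[str], str], questions: List[str]) -> Dict[str, str]:
    # Collect one result per question instead of stopping at the first exception.
    results: Dict[str, str] = {}
    for question in questions:
        try:
            results[question] = answer_fn(question)
        except Exception as exc:  # answer_sql re-raises whichever step failed
            results[question] = f"ERROR ({type(exc).__name__}): {exc}"
    return results
```

For example, `run_questions(answer_sql, ["Summarize sales by region."])` returns a dictionary mapping each question to its answer or error message.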