# MLX inference - Text-to-SQL to extract chart data

> For this use case, we don't need to fine-tune a model, because pre-trained text-to-SQL models are already available. Instead, we use `LLM selection` and `Prompt Engineering` to choose the model and the prompt that best generate the SQL needed to extract data that can be represented in a chart.

## Prerequisites

In [None]:
%%capture
# Install the MLX-LM framework for running (and fine-tuning) LLMs on the Apple Silicon GPU
%pip install mlx-lm

# Install Hugging Face Hub to download models from the Hugging Face Hub
%pip install huggingface_hub

## LLM Selection

During the selection process, we found [defog/llama-3-sqlcoder-8b](https://huggingface.co/defog/llama-3-sqlcoder-8b) to be the most capable text-to-SQL model for our requirements.

## Prompt Engineering

In this step, we construct the prompt used to generate the SQL query. To integrate the query result with the application, we need to provide the model with the following information:

### 1. The main task

#### Example: 
```text
Generate a SQL query to answer the question: "What is the total number of sales for each product?"
```

### 2. The instructions

Specify how the model should write the SQL query. Most chart data comes from an aggregate function, and the result set has two columns: one for the label and one for the value. To make the data easy to extract, we adopt a convention: the aggregate function is always aliased as 'value' and the GROUP BY column is always aliased as 'label'.

#### Example: 

```text
- If you cannot answer the question with the available database schema, return 'I do not know'
- The aggregate function always has the alias 'value'
- The group by column always has the alias 'label'
...
```
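Because every chart query returns the same two aliased columns, the application can turn any result set into chart data without inspecting the query itself. The sketch below is a minimal illustration, assuming an SQLite database; the `to_chart_data` helper and the sample rows are hypothetical and not part of the original notebook.

```python
import sqlite3


def to_chart_data(conn: sqlite3.Connection, sql: str) -> dict:
    """Run a query that follows the label/value alias convention and
    return its rows as two parallel lists for a chart library."""
    cursor = conn.execute(sql)
    # cursor.description holds one tuple per result column; item 0 is the column name
    columns = [col[0] for col in cursor.description]
    if "label" not in columns or "value" not in columns:
        raise ValueError(f"expected 'label' and 'value' columns, got {columns}")
    # Look the columns up by alias so their order in the SELECT does not matter
    label_idx, value_idx = columns.index("label"), columns.index("value")
    rows = cursor.fetchall()
    return {
        "labels": [row[label_idx] for row in rows],
        "values": [row[value_idx] for row in rows],
    }


# Hypothetical in-memory data for illustration only
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE salespeople (salesperson_id INTEGER PRIMARY KEY, name TEXT, region TEXT)")
conn.executemany(
    "INSERT INTO salespeople (name, region) VALUES (?, ?)",
    [("Ana", "North"), ("Ben", "North"), ("Cy", "South")],
)

chart = to_chart_data(
    conn,
    "SELECT region AS label, COUNT(*) AS value FROM salespeople GROUP BY region ORDER BY label",
)
print(chart)  # {'labels': ['North', 'South'], 'values': [2, 1]}
```

Looking the columns up by alias, rather than by position, keeps the extraction working even if the model emits `value` before `label`.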

### 3. The database schema

Provide the database schema as `CREATE TABLE` statements so the model knows which tables and columns exist. The schema uses the following format, with comments describing each column and the joinable keys, and is supplied to the model at runtime.

#### Example: 

```sql
CREATE TABLE products (
  product_id INTEGER PRIMARY KEY, -- Unique ID for each product
  name VARCHAR(50), -- Name of the product
  price DECIMAL(10,2), -- Price of each unit of the product
  quantity INTEGER  -- Current quantity in stock
);

...

-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id
```
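Since the schema is supplied at runtime, it can be read from the database instead of being hard-coded. A minimal sketch, assuming the application uses SQLite, which keeps the `CREATE TABLE` text of each table in its `sqlite_master` catalog. Join hints are not stored there, so this hypothetical `load_ddl` helper appends them as SQL comments passed in by the caller.

```python
import sqlite3
from typing import List, Optional


def load_ddl(conn: sqlite3.Connection, join_hints: Optional[List[str]] = None) -> str:
    """Collect CREATE TABLE statements from an SQLite catalog and append
    optional join hints as SQL comments."""
    rows = conn.execute(
        "SELECT sql FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' "  # skip SQLite internal tables
        "ORDER BY name"
    ).fetchall()
    statements = [row[0] + ";" for row in rows]
    hints = [f"-- {hint}" for hint in (join_hints or [])]
    ddl = "\n\n".join(statements)
    if hints:
        ddl += "\n\n" + "\n".join(hints)
    return ddl


# Hypothetical usage with an in-memory database
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE salespeople (salesperson_id INTEGER PRIMARY KEY, region VARCHAR(50))")
runtime_ddl = load_ddl(conn, ["sales.salesperson_id can be joined with salespeople.salesperson_id"])
print(runtime_ddl)
```

The resulting string can be passed to the prompt template in place of a hand-written `ddl_statements` value.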

In [3]:
user_question = "total number of sales people by region in a bar chart"
instructions = """
- If you cannot answer the question with the available database schema, return 'I do not know'
- Remember that revenue is price multiplied by quantity
- Remember that cost is supply_price multiplied by quantity
- The aggregate function always has the alias 'value'
- The group by column always has the alias 'label'
"""
ddl_statements = """
CREATE TABLE products (
  product_id INTEGER PRIMARY KEY, -- Unique ID for each product
  name VARCHAR(50), -- Name of the product
  price DECIMAL(10,2), -- Price of each unit of the product
  quantity INTEGER  -- Current quantity in stock
);

CREATE TABLE customers (
   customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
   name VARCHAR(50), -- Name of the customer
   address VARCHAR(100) -- Mailing address of the customer
);

CREATE TABLE salespeople (
  salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
  name VARCHAR(50), -- Name of the salesperson
  region VARCHAR(50) -- Geographic sales region
);

CREATE TABLE sales (
  sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
  product_id INTEGER, -- ID of product sold
  customer_id INTEGER,  -- ID of customer who made purchase
  salesperson_id INTEGER, -- ID of salesperson who made the sale
  sale_date DATE, -- Date the sale occurred
  quantity INTEGER -- Quantity of product sold
);

CREATE TABLE product_suppliers (
  supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
  product_id INTEGER, -- Product ID supplied
  supply_price DECIMAL(10,2) -- Unit price charged by supplier
);

-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id
-- sales.salesperson_id can be joined with salespeople.salesperson_id
-- product_suppliers.product_id can be joined with products.product_id  
"""

prompt = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Generate a SQL query to answer this question: `{user_question}`.
{instructions}

DDL statements:
{ddl_statements}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{user_question}`:
```sql
"""

prompt = prompt.format(user_question=user_question, instructions=instructions, ddl_statements=ddl_statements)

## Inference

In [None]:
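from mlx_lm import load, generate
from mlx_lm.sample_utils import make_sampler

model_id = 'defog/llama-3-sqlcoder-8b'
model, tokenizer = load(model_id)

# Temperature 0.0 means greedy decoding, the usual choice for deterministic SQL generation.
# Recent mlx-lm releases take sampling settings through a sampler instead of a `temp` argument.
sampler = make_sampler(temp=0.0)
response = generate(model, tokenizer, prompt=prompt, max_tokens=2048, verbose=True, sampler=sampler)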
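The prompt ends by opening a SQL code fence, so the completion should contain the query followed by a closing fence, or the fallback answer 'I do not know'. The sketch below pulls out just the query, assuming `generate` returns only the completion text (not the prompt); `extract_sql` is a hypothetical helper name.

```python
from typing import Optional

# Three backticks, built this way so the fence never appears literally in the code
CLOSING_FENCE = "`" * 3


def extract_sql(response: str) -> Optional[str]:
    """Return the generated SQL, or None when the model declined to answer."""
    # The prompt already opened the SQL fence, so the query ends at the first closing fence
    sql = response.split(CLOSING_FENCE, 1)[0].strip()
    # The instructions ask the model to answer 'I do not know' when the schema is insufficient
    if not sql or "i do not know" in sql.lower():
        return None
    return sql


sample = "SELECT region AS label, COUNT(*) AS value FROM salespeople GROUP BY region;\n" + CLOSING_FENCE
print(extract_sql(sample))
```

In the notebook, calling `extract_sql(response)` on the string returned by `generate` yields either a query ready to execute or `None`, which the application can surface as "no answer".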

## Saving the model

### GGUF format

In [None]:
import sys
from huggingface_hub import snapshot_download

out_type = "q8_0"  # --outtype choices include f32, f16, bf16, q8_0

# Download the model from the Hugging Face Hub
snapshot_download(model_id, local_dir="sqlcoder", local_dir_use_symlinks=False, revision="main")

# Clone llama.cpp and install the Python dependencies of its conversion script
!git clone --depth 1 https://github.com/ggerganov/llama.cpp.git
%pip install -q -r llama.cpp/requirements.txt

# Convert the downloaded Hugging Face model to GGUF format with the kernel's Python interpreter
!{sys.executable} llama.cpp/convert_hf_to_gguf.py sqlcoder --outfile chart-data.gguf --outtype {out_type}

print("Conversion to GGUF format is complete.")
!rm -rf llama.cpp
!rm -rf sqlcoder
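A quick sanity check on the converted file is to read its header. A GGUF file starts with the 4-byte ASCII magic `GGUF` followed by a little-endian 32-bit version number. This minimal sketch checks only those first 8 bytes, not the tensors. The `read_gguf_version` name is hypothetical, and the sketch assumes the output file sits in the notebook's working directory.

```python
import struct
from pathlib import Path


def read_gguf_version(path: str) -> int:
    """Check a file's GGUF magic bytes and return its format version."""
    with open(path, "rb") as f:
        header = f.read(8)
    # GGUF header: 4-byte ASCII magic "GGUF", then a little-endian uint32 version
    if len(header) < 8 or header[:4] != b"GGUF":
        raise ValueError(f"{path} does not start with the GGUF magic bytes")
    (version,) = struct.unpack("<I", header[4:8])
    return version


# chart-data.gguf is the file written by the conversion cell above
if Path("chart-data.gguf").exists():
    print("GGUF version:", read_gguf_version("chart-data.gguf"))
```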