## Imports

In [None]:
!pip install torch transformers bitsandbytes accelerate sqlparse

In [None]:
import torch
import re
import pandas as pd
import sqlite3
import warnings
import sqlparse
from IPython.display import display, Markdown
from transformers import AutoTokenizer, AutoModelForCausalLM

warnings.filterwarnings("ignore")

## Download the model

We are using Defog's Llama-3-based SQLCoder-8B, which was designed as "a state-of-the-art model for generating SQL queries from natural language".

https://huggingface.co/defog/llama-3-sqlcoder-8b

In [None]:
# Hugging Face Hub id of the text-to-SQL checkpoint used throughout this notebook.
model_name = "defog/llama-3-sqlcoder-8b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the causal LM with 4-bit quantized weights; device placement is left to
# the library via device_map="auto".
text2sql_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # NOTE(review): trust_remote_code allows code shipped with the checkpoint
    # to run locally — confirm it is actually required for this model.
    trust_remote_code=True,
    load_in_4bit=True,
    device_map="auto",
    use_cache=True,
)

## Load mockup travel data

We have prepared some mockup travel data for travel durations, modes of transportation and cost from Athens to a selection of Greek islands. We load this data into memory, simulating an SQL database. You can have a look at the function if you want to see the structure of the data.



In [None]:
# Execute mockup_data.py in the notebook namespace. add_available_islands is not
# defined in this notebook, so it presumably comes from that script.
%run mockup_data.py
# Open the SQLite file that the generated queries will be run against.
conn = sqlite3.connect('data/travel_information.sqlite')
available_islands = add_available_islands()

## Generate database description

We need to supply information such as the schema of the database to the model that will create the SQL statement for us based on the request we formulate.

In [None]:
# The prompt template is assembled from four sections. The {question}
# placeholders are filled in later with str.format.
TASK = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]
"""

# Behavioural rules given to the model.
INSTRUCTIONS = """### Instructions
- If you cannot answer the question with the database schema provided below, return 'I do not know'.
- Always return every attribute.

"""

# DDL of the single `routes` table plus column hints, passed to the model as context.
DATABASET_SCHEMA = """### Database Schema
This query will run on a database and the database schema is represented in this string:

CREATE TABLE routes (
  route_id integer primary key,
  origin_city varchar(50),
  destination_city varchar(50),
  distance_km decimal(10,2),
  travel_time_ferry decimal(10,2),
  travel_time_plane decimal(10,2),
  travel_time_sailboat decimal(10,2),
  travel_time_speedboat decimal(10,2),
  price_chf_ferry decimal(10,2),
  price_chf_plane decimal(10,2),
  price_chf_sailboat decimal(10,2),
  price_chf_speedboat decimal(10,2)
);
-- routes.origin_city is always Athens
-- routes.destination_city is the target city of the travel. It is always written with a capital letter.
"""

# Answer stub: the prompt ends with the "[SQL]" marker, which is also the marker
# the generated text is split on when the query is extracted.
ANSWER = """### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]"""

database_description = TASK + INSTRUCTIONS + DATABASET_SCHEMA + ANSWER

# Render the full template (placeholders still unfilled) for inspection.
display(Markdown(database_description))


## Generate the SQL query

Using the description of the database and the instructions we crafted above, we can now ask the model to create an SQL statement for us.

In [None]:
def create_request_for_text2sql(island: str) -> str:
  """Build the natural-language question that is sent to the text-to-SQL model.

  Args:
    island: Name of the destination island; appended verbatim to the request.

  Returns:
    The request sentence ending with the island name.
  """
  prefix = "Give me the all the available travel information from Athens to "
  return "".join((prefix, island))

def clean_generated_sql(generated_sql: str) -> str:
  """Extract the first SELECT statement from the model output.

  Args:
    generated_sql: Text produced by the model (anything after the prompt).

  Returns:
    The substring from the first ``SELECT`` up to and including the next ``;``.

  Raises:
    ValueError: If the text contains no ``SELECT ... ;`` statement, e.g. when
      the model answers 'I do not know' as the prompt instructions request.
  """
  # Non-greedy so that only the first statement is kept if several follow.
  pattern = r"SELECT[\s\S]+?;"
  match = re.search(pattern, generated_sql)
  if match is None:
    # Without this guard the failure surfaced as an opaque AttributeError on None.
    raise ValueError(
        f"No SQL SELECT statement found in model output: {generated_sql!r}"
    )
  return match.group(0)

def generate_sql_query(question: str) -> str:
  """Ask the text-to-SQL model for a query that answers ``question``.

  The prompt template is filled with the question, the model generates without
  sampling and with a single beam, and the SELECT statement found after the
  last ``[SQL]`` marker is returned.

  Args:
    question: Natural-language request to translate into SQL.

  Returns:
    The cleaned SQL statement extracted from the model output.
  """
  prompt = database_description.format(question=question)
  model_inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

  generation_kwargs = dict(
      num_return_sequences=1,
      eos_token_id=tokenizer.eos_token_id,
      # The EOS id doubles as the padding id.
      pad_token_id=tokenizer.eos_token_id,
      max_new_tokens=400,
      do_sample=False,
      num_beams=1,
  )
  token_ids = text2sql_model.generate(**model_inputs, **generation_kwargs)
  decoded = tokenizer.batch_decode(token_ids, skip_special_tokens=True)

  torch.cuda.empty_cache()
  torch.cuda.synchronize()

  # Keep only the text after the last "[SQL]" marker (the prompt ends with it).
  completion = decoded[0].split("[SQL]")[-1]
  formatted = sqlparse.format(completion, reindent=True)

  return clean_generated_sql(formatted)

Select an island to ask about. Make sure that the island is in the list of known islands (or try asking for an island that does not exist and see what happens).

In [None]:
# Destination island used to build the natural-language request.
island = "Crete"

In [None]:
# Turn the selected island into a natural-language request.
text2sql_query = create_request_for_text2sql(island=island)

print(f"Request is: {text2sql_query}")

# Let the model translate the request into SQL and render the query in the cell output.
sql_query = generate_sql_query(text2sql_query)
display(Markdown(sql_query))

## Use the generated SQL to fetch information from the database

We have everything together now to fetch data from the database using the SQL that was created for us.

Note: If you dropped the connection to the in-memory database, uncomment and run the commented-out lines in the cell below to reconnect.

In [None]:
def fetch_data_from_database(sql_query: str, conn: sqlite3.Connection) -> pd.DataFrame:
  """Run ``sql_query`` on ``conn`` and return the result set as a DataFrame.

  Args:
    sql_query: SQL statement to execute.
    conn: Open SQLite connection to run the statement on.

  Returns:
    A DataFrame holding the rows returned by the query.
  """
  return pd.read_sql_query(sql=sql_query, con=conn)

# Execute these lines in case your connection to the in-memory database dropped
# try:
#   conn.close()
#   conn = sqlite3.connect('data/travel_information.sqlite')
#   add_mockup_data_to_memory(conn=conn)
# except:
#   pass



# Run the generated SQL and show the result transposed.
data = fetch_data_from_database(sql_query=sql_query, conn=conn)
data.T

In [None]:
# Close the SQLite connection now that the data has been fetched.
conn.close()

## Full Code Example

If you want to play around with different islands and check the generated SQL, use the cells below, which include all the functions we have defined above.

In [None]:
# Destination island for the full example; change it to try other requests.
island = "Karpathos"

In [None]:
# NOTE: create_request_for_text2sql and generate_sql_query are called here
# before the copies further down in this cell are defined, so this cell relies
# on the definitions from the earlier cells already being in the kernel namespace.
%run mockup_data.py
text2sql_query = create_request_for_text2sql(island=island)

print(f"Request is: {text2sql_query}")

sql_query = generate_sql_query(text2sql_query)

print(f"SQL query is: {sql_query}")

# Open a fresh connection (the previous one was closed in an earlier cell).
conn = sqlite3.connect('data/travel_information.sqlite')
available_islands = add_available_islands()

# The prompt template is assembled from four sections. The {question}
# placeholders are filled in later with str.format.
TASK = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]
"""

# Behavioural rules given to the model.
INSTRUCTIONS = """### Instructions
- If you cannot answer the question with the database schema provided below, return 'I do not know'.
- Always return every attribute.

"""

# DDL of the single `routes` table plus column hints, passed to the model as context.
DATABASET_SCHEMA = """### Database Schema
This query will run on a database and the database schema is represented in this string:

CREATE TABLE routes (
  route_id integer primary key,
  origin_city varchar(50),
  destination_city varchar(50),
  distance_km decimal(10,2),
  travel_time_ferry decimal(10,2),
  travel_time_plane decimal(10,2),
  travel_time_sailboat decimal(10,2),
  travel_time_speedboat decimal(10,2),
  price_chf_ferry decimal(10,2),
  price_chf_plane decimal(10,2),
  price_chf_sailboat decimal(10,2),
  price_chf_speedboat decimal(10,2)
);
-- routes.origin_city is always Athens
-- routes.destination_city is the target city of the travel. It is always written with a capital letter.
"""

# Answer stub: the prompt ends with the "[SQL]" marker, which is also the marker
# the generated text is split on when the query is extracted.
ANSWER = """### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]"""

database_description = TASK + INSTRUCTIONS + DATABASET_SCHEMA + ANSWER

def create_request_for_text2sql(island: str) -> str:
  """Build the natural-language question that is sent to the text-to-SQL model.

  Args:
    island: Name of the destination island; appended verbatim to the request.

  Returns:
    The request sentence ending with the island name.
  """
  prefix = "Give me the all the available travel information from Athens to "
  return "".join((prefix, island))

def clean_generated_sql(generated_sql: str) -> str:
  """Extract the first SELECT statement from the model output.

  Args:
    generated_sql: Text produced by the model (anything after the prompt).

  Returns:
    The substring from the first ``SELECT`` up to and including the next ``;``.

  Raises:
    ValueError: If the text contains no ``SELECT ... ;`` statement, e.g. when
      the model answers 'I do not know' as the prompt instructions request.
  """
  # Non-greedy so that only the first statement is kept if several follow.
  pattern = r"SELECT[\s\S]+?;"
  match = re.search(pattern, generated_sql)
  if match is None:
    # Without this guard the failure surfaced as an opaque AttributeError on None.
    raise ValueError(
        f"No SQL SELECT statement found in model output: {generated_sql!r}"
    )
  return match.group(0)

def generate_sql_query(question: str) -> str:
  """Ask the text-to-SQL model for a query that answers ``question``.

  The prompt template is filled with the question, the model generates without
  sampling and with a single beam, and the SELECT statement found after the
  last ``[SQL]`` marker is returned.

  Args:
    question: Natural-language request to translate into SQL.

  Returns:
    The cleaned SQL statement extracted from the model output.
  """
  prompt = database_description.format(question=question)
  model_inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

  generation_kwargs = dict(
      num_return_sequences=1,
      eos_token_id=tokenizer.eos_token_id,
      # The EOS id doubles as the padding id.
      pad_token_id=tokenizer.eos_token_id,
      max_new_tokens=400,
      do_sample=False,
      num_beams=1,
  )
  token_ids = text2sql_model.generate(**model_inputs, **generation_kwargs)
  decoded = tokenizer.batch_decode(token_ids, skip_special_tokens=True)

  torch.cuda.empty_cache()
  torch.cuda.synchronize()

  # Keep only the text after the last "[SQL]" marker (the prompt ends with it).
  completion = decoded[0].split("[SQL]")[-1]
  formatted = sqlparse.format(completion, reindent=True)

  return clean_generated_sql(formatted)

def fetch_data_from_database(sql_query: str, conn: sqlite3.Connection) -> pd.DataFrame:
  """Run ``sql_query`` on ``conn`` and return the result set as a DataFrame.

  Args:
    sql_query: SQL statement to execute.
    conn: Open SQLite connection to run the statement on.

  Returns:
    A DataFrame holding the rows returned by the query.
  """
  return pd.read_sql_query(sql=sql_query, con=conn)

# Run the generated SQL and show the result transposed.
fetch_data_from_database(sql_query=sql_query, conn=conn).T