# SQLCoder
Run the cells below to perform inference with our text-to-SQL LLM, SQLCoder.


## Setup

In [7]:
!pip install torch transformers bitsandbytes accelerate



In [8]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

In [9]:
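import locale
# Colab workaround: force a UTF-8 preferred encoding, otherwise later `!pip` shell
# commands can fail with "A UTF-8 locale is required"
locale.getpreferredencoding = lambda: "UTF-8"
!pip install sqlvalidator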



In [10]:
import sqlite3
import sqlvalidator

In [11]:
torch.cuda.is_available()

True

## Download the Model
Use an A100 on Colab Pro (or any machine of your own with more than 30 GB of VRAM) to load the model in bf16. If that is unavailable, use a GPU with at least 20 GB of VRAM to load it in 8-bit, or one with at least 12 GB to load it in 4-bit. On Colab, the 4-bit setup works on a V100 but crashes on a T4.

Downloading the model and loading it into memory takes around 10 minutes the first time, so please be patient :)
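The VRAM thresholds above follow mostly from the size of the weights. A rough, weight-only estimate can be sketched as below, assuming about 15.5B parameters (the size of the StarCoder base that WizardCoder-15B is fine-tuned from) and ignoring activations, the KV cache and CUDA overhead:

```python
# Rough weight-only memory estimate for a ~15.5B-parameter model.
# Assumption: activations, the KV cache and CUDA overhead are ignored,
# which is why the recommended VRAM figures leave extra headroom.


def estimate_weight_gb(n_params: float, bits_per_param: int) -> float:
    """Return the approximate size of the model weights in GB (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


N_PARAMS = 15.5e9  # assumed parameter count of the StarCoder-based 15B model

for label, bits in [("bf16", 16), ("8-bit", 8), ("4-bit", 4)]:
    print(f"{label:>5}: ~{estimate_weight_gb(N_PARAMS, bits):.2f} GB of weights")
```

The bf16 estimate (~31 GB) is consistent with the ~31 GB `pytorch_model.bin` checkpoint; the gap between each estimate and the recommended VRAM is headroom for everything this sketch ignores.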

In [12]:
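from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_name = "WizardLM/WizardCoder-15B-V1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 4-bit quantization needs roughly 12 GB of VRAM.
# For 8-bit, use BitsAndBytesConfig(load_in_8bit=True) instead;
# for bf16 on an A100, drop quantization_config and pass torch_dtype=torch.bfloat16.
quant_config = BitsAndBytesConfig(load_in_4bit=True)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    quantization_config=quant_config,
    device_map="auto",
    use_cache=True,
)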


## Set the Question & Prompt and Tokenize
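Feel free to change the question below. To experiment with your own database, change the path passed to `connect_db` below (and the path used in `getResponse`); the schema is read from the database's `sqlite_master` table and inserted into the prompt.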
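A minimal sketch of the same schema lookup that `connect_db` performs, run against a hypothetical in-memory database (the `sales` table and the `read_schema` helper are illustrative only). Swap `":memory:"` for the path to your own SQLite file:

```python
import sqlite3


def read_schema(conn: sqlite3.Connection) -> str:
    """Concatenate the CREATE TABLE statements stored in sqlite_master."""
    rows = conn.execute("SELECT sql FROM sqlite_master WHERE type='table';").fetchall()
    return "".join(row[0] + "\n" for row in rows)


# Hypothetical example database; replace with sqlite3.connect("/path/to/your.sqlite3")
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE sales (region varchar(20), month varchar(20), amount float)")
schema_example = read_schema(conn)
conn.close()

print(schema_example)
```

SQLite keeps the original `CREATE TABLE` text (including any `--` comments), which is why the schema printed below still carries the column descriptions that help the model.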

In [100]:
def reproompt():
  # TODO: re-prompt either WizardCoder or Llama 2 when a generated query fails.
  # Update: a first version works but still needs refining.
  print("reproompting")



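# Prompt used to ask the model to repair a query that failed to execute.
# Kept at column 0 so the model does not see stray indentation.
FIX_PROMPT_TEMPLATE = """### Instructions:
Your task is to fix an SQL query, given a sqlite3 database schema and the error message.
Adhere to these rules:
- **Deliberately go through the SQL query and database schema word by word** to appropriately fix the query according to the database
- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
- When creating a ratio, always cast the numerator as float

### Input:
Fix the following SQL query `{query}`.
Running it raised this error: `{error}`
This query will run on a database whose schema is represented in this string:
{schema}
### Response:
Based on your instructions, here is the fixed SQL query:
```sql
"""

MAX_FIX_ATTEMPTS = 2  # stop re-prompting after this many consecutive failures
fix_attempts = 0

def validateSQL(response, path):
  global fix_attempts
  # check if safe (currently disabled)
  # if not sqlvalidator.parse(response).is_valid():
  #   print("SQL Invalid")
  #   return 'Invalid'

  # check if executable
  conn = sqlite3.connect(path)
  try:
    cursor = conn.cursor()
    cursor.execute(response)
    rows = cursor.fetchall()
  except sqlite3.Error as e:
    print(e)  # error reported by SQLite
    print("Error executing SQL")
    conn.close()
    if fix_attempts >= MAX_FIX_ATTEMPTS:
      print("Giving up after " + str(fix_attempts) + " fix attempts")
      fix_attempts = 0
      return 'Invalid'
    fix_attempts += 1
    # Ask the model to repair the query; getResponse validates the new query in turn.
    # `schema` is the global set by connect_db below.
    prompt = FIX_PROMPT_TEMPLATE.format(query=response, error=e, schema=schema)
    getResponse(prompt)
    return 'Invalid'

  conn.close()
  fix_attempts = 0
  print("executed")
  # Collect every column of every returned row
  sqlResult = ""
  for row in rows:
    sqlResult += ", ".join(str(value) for value in row) + '\n'
  print("Data: " + sqlResult)
  print("Valid")
  return 'Valid'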

def connect_db(path):
    conn = sqlite3.connect(path)
    # Create a cursor object
    cursor = conn.cursor()

    # Query to retrieve the schema information for all tables
    query = "SELECT sql FROM sqlite_master WHERE type='table';"

    # Execute the query
    cursor.execute(query)

    # Fetch all the results and concatenate them into a single string
    schema_string = ""
    for row in cursor.fetchall():
        schema_string += row[0] + '\n'

    # Close the cursor and the database connection
    cursor.close()
    conn.close()
    print(schema_string)
    return schema_string

question = "Where was the biggest vaccination rate achieved?"
schema = connect_db("/content/example-covid-vaccinations.sqlite3")

CREATE TABLE covid_vaccinations
    -- COVID-19 Vaccination Rates
    -- URL: https://ws.cso.ie/public/api.restful/PxStat.Data.Cube_API.ReadDataset/CDC45/CSV/1.0/en
(
  STATISTIC_CODE varchar(10), -- Statistic code
  Statistic_Label varchar(30), -- Statistic label
  `TLIST(M1)` int, -- Time period of statistic
  Month varchar(20), -- Time period human-readable
  `C03898V04649` varchar(30),
  `Local Electoral Area` varchar(50),
  `C02076V03371` varchar(10),
  `Age Group` varchar(30),
  `UNIT` varchar(10),
  `VALUE` float
)
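The retry path in `validateSQL` re-prompts the model whenever SQLite rejects a query. The same idea can be written as an explicit, bounded loop. In this sketch `run_with_repair` and `fake_repair` are hypothetical names, the `rates` table is made-up data, and `fake_repair` stands in for re-prompting the LLM with the error message so the example runs without a GPU:

```python
import sqlite3
from typing import Callable, List, Optional, Tuple


def run_with_repair(
    conn: sqlite3.Connection,
    sql: str,
    repair: Callable[[str, str], str],
    max_attempts: int = 2,
) -> Optional[List[Tuple]]:
    """Run `sql`; on a SQLite error, ask `repair(sql, error)` for a new query.

    Gives up after `max_attempts` repairs and returns None.
    """
    for attempt in range(max_attempts + 1):
        try:
            return conn.execute(sql).fetchall()
        except sqlite3.Error as e:
            print(f"attempt {attempt} failed: {e}")
            if attempt == max_attempts:
                break
            sql = repair(sql, str(e))
    return None


# Made-up example table, not the COVID dataset
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE rates (area_name TEXT, rate REAL)")
conn.executemany("INSERT INTO rates VALUES (?, ?)", [("Area A", 0.81), ("Area B", 0.92)])


def fake_repair(sql: str, error: str) -> str:
    # Hypothetical stand-in for re-prompting the LLM with the failed query and error
    return "SELECT area_name FROM rates ORDER BY rate DESC LIMIT 1;"


rows = run_with_repair(conn, "SELECT area FROM rates;", fake_repair)
print(rows)  # [('Area B',)]
```

Passing the error text to `repair` mirrors putting the SQLite error message into the fix prompt, and the attempt limit keeps a failing query from triggering re-prompts indefinitely.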



In [97]:
prompt = """### Instructions:
Your task is to convert a question into a SQL query, given a sqlite3 database schema.
Adhere to these rules:
- **Deliberately go through the question and database schema word by word** to appropriately answer the question
- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
- When creating a ratio, always cast the numerator as float

### Input:
Generate a SQL query that answers the question `{question}`.
This query will run on a database whose schema is represented in this string:
{schema}
### Response:
Based on your instructions, here is the SQL query I have generated to answer the question `{question}`:
```sql
""".format(question=question, schema=schema)
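The prompt's rule to cast the numerator as float exists because SQLite performs integer division when both operands are integers. A quick check with the standard-library `sqlite3` module:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Both operands are integers, so SQLite truncates the result
int_ratio = conn.execute("SELECT 1 / 2").fetchone()[0]
# Casting the numerator to FLOAT makes SQLite use real division
float_ratio = conn.execute("SELECT CAST(1 AS FLOAT) / 2").fetchone()[0]
conn.close()

print(int_ratio, float_ratio)  # 0 0.5
```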

In [98]:
# Id of the ``` token; generation stops when the model closes the ```sql block
eos_token_id = tokenizer.convert_tokens_to_ids(["```"])[0]

## Generate the SQL
This can be excruciatingly slow on a V100 with 4-bit quantization, taking around 1-2 minutes per query. On a single 40 GB A100, it takes about 10-20 seconds.
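To check these timings on your own hardware, a small wall-clock wrapper is enough. This is a sketch; `timed` is a hypothetical helper, not part of transformers:

```python
import time
from typing import Any, Callable, Tuple


def timed(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Tuple[Any, float]:
    """Call fn(*args, **kwargs) and return (result, elapsed wall-clock seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed


# Stand-in workload; in the notebook this would be e.g. timed(getResponse, prompt)
total, seconds = timed(sum, range(1_000_000))
print(f"sum={total} in {seconds:.3f}s")
```

In the notebook, `_, seconds = timed(getResponse, prompt)` would measure tokenization, beam search, decoding and the SQL check together, which is what the per-query figures above describe.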

In [101]:
def getResponse(prompt):
  print("Question: " + question)
  # Stop generating once the model emits the closing ``` of the SQL block
  eos_token_id = tokenizer.convert_tokens_to_ids(["```"])[0]
  inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
  generated_ids = model.generate(
      **inputs,
      num_return_sequences=1,
      eos_token_id=eos_token_id,
      pad_token_id=eos_token_id,
      max_new_tokens=400,
      do_sample=False,  # deterministic beam search with 5 beams
      num_beams=5
  )
  outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
  # Keep only the first statement inside the last ```sql block
  response = outputs[0].split("```sql")[-1].split("```")[0].split(";")[0].strip() + ";"
  print("Response: ")
  print(response)

  validateSQL(response, "/content/example-covid-vaccinations.sqlite3")
  return response

response = getResponse(prompt)

Question: Where was the biggest vaccination rate achieved?
Response: 
SELECT MAX(CAST(`C03898V04649` AS FLOAT))
FROM covid_vaccinations;
executed
Data: 2.0

Valid


In [7]:
torch.cuda.empty_cache()
torch.cuda.synchronize()
# Release cached, unused GPU memory so you can generate more results without running out of memory.
# This is particularly important on Colab; memory management is much more straightforward
# when running on an inference service.


And voila! Here's the generated SQL:

In [19]:
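# `outputs` is local to getResponse, so generate and decode again to inspect the SQL here
outputs = tokenizer.batch_decode(model.generate(**tokenizer(prompt, return_tensors="pt").to("cuda"), eos_token_id=eos_token_id, pad_token_id=eos_token_id, max_new_tokens=400, do_sample=False, num_beams=5), skip_special_tokens=True)
response = outputs[0].split("```sql")[-1].split("```")[0].split(";")[0].strip() + ";"
print(response)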

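The string slicing used to pull the SQL out of the decoded output can be factored into a pure function that is easy to test without loading the model. A sketch, with `extract_sql` as a hypothetical name:

```python
def extract_sql(decoded: str) -> str:
    """Return the first statement inside the last sql fence of `decoded`.

    Mirrors the notebook's slicing: take the text after the last "```sql",
    cut at the next "```", keep what precedes the first ";", and re-append ";".
    """
    body = decoded.split("```sql")[-1].split("```")[0]
    return body.split(";")[0].strip() + ";"


example = "...here is the SQL query:\n```sql\nSELECT MAX(rate) FROM rates;\nSELECT 2;\n```"
print(extract_sql(example))  # SELECT MAX(rate) FROM rates;
```

If the decoded text contains no opening sql fence, the whole text is treated as the query, so it is worth checking the result (for example with `validateSQL`) before relying on it.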