In [85]:
import openai  # establish connection with OpenAI API 
import duckdb # to simulate SQL commands
import time, os  #  to load CSV files and reformat certain fields
import pandas as pd # data processing 

In [82]:
key = "Insert your OpenAI key here...."

In [8]:
files = [x for x in os.listdir(path = "../LLM_SQL") if ".csv" in x]
files

['Chicago-crimes-2021.csv',
 'Chicago-crimes-2022.csv',
 'Chicago-crimes-2023.csv']

In [9]:
chicago_crime = pd.concat((pd.read_csv("../LLM_SQL" +"/" + f) for f in files), ignore_index=True)
chicago_crime.head()

Unnamed: 0,ID,Case Number,Date,Block,IUCR,Primary Type,Description,Location Description,Arrest,Domestic,...,Ward,Community Area,FBI Code,X Coordinate,Y Coordinate,Year,Updated On,Latitude,Longitude,Location
0,25953,JE240540,05/24/2021 03:06:00 PM,020XX N LARAMIE AVE,110,HOMICIDE,FIRST DEGREE MURDER,STREET,True,False,...,36.0,19,01A,1141387.0,1913179.0,2021,11/18/2023 03:39:49 PM,41.917838,-87.755969,"(41.917838056, -87.755968972)"
1,26038,JE279849,06/26/2021 09:24:00 AM,062XX N MC CORMICK RD,110,HOMICIDE,FIRST DEGREE MURDER,PARKING LOT,True,False,...,50.0,13,01A,1152781.0,1941458.0,2021,11/18/2023 03:39:49 PM,41.995219,-87.713355,"(41.995219444, -87.713354912)"
2,12342615,JE202211,04/17/2021 03:20:00 PM,081XX S PRAIRIE AVE,325,ROBBERY,VEHICULAR HIJACKING,RESIDENCE,True,False,...,6.0,44,03,1179448.0,1851073.0,2021,09/14/2023 03:41:59 PM,41.746626,-87.618032,"(41.746626309, -87.618031954)"
3,26262,JE366265,09/08/2021 04:45:00 PM,047XX W HARRISON ST,110,HOMICIDE,FIRST DEGREE MURDER,CAR WASH,True,False,...,24.0,25,01A,1144907.0,1896933.0,2021,09/14/2023 03:41:59 PM,41.873191,-87.743447,"(41.873191445, -87.743446563)"
4,13209581,JG422927,08/01/2021 12:00:00 AM,012XX E 78TH ST,1563,SEX OFFENSE,CRIMINAL SEXUAL ABUSE,APARTMENT,False,False,...,8.0,45,17,,,2021,09/14/2023 03:43:09 PM,,,


In [10]:
duckdb.sql("DESCRIBE SELECT * FROM chicago_crime;")

┌──────────────────────┬─────────────┬─────────┬─────────┬─────────┬─────────┐
│     column_name      │ column_type │  null   │   key   │ default │  extra  │
│       varchar        │   varchar   │ varchar │ varchar │ varchar │ varchar │
├──────────────────────┼─────────────┼─────────┼─────────┼─────────┼─────────┤
│ ID                   │ BIGINT      │ YES     │ NULL    │ NULL    │ NULL    │
│ Case Number          │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ Date                 │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ Block                │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ IUCR                 │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ Primary Type         │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ Description          │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ Location Description │ VARCHAR     │ YES     │ NULL    │ NULL    │ NULL    │
│ Arrest               │ BOOLEAN     │ YES     │ NUL

In [39]:
import re
tdf = duckdb.sql("DESCRIBE SELECT * FROM chicago_crime;")
df = tdf.to_df()[['column_name','column_type']]
df["column_joint"] = df["column_name"] + " " +  df["column_type"]
x = str(df['column_joint'].values)

output_string = re.sub(r"[',\[\]]", "", x)
output_string

'ID BIGINT Case Number VARCHAR Date VARCHAR Block VARCHAR\n IUCR VARCHAR Primary Type VARCHAR Description VARCHAR\n Location Description VARCHAR Arrest BOOLEAN Domestic BOOLEAN\n Beat BIGINT District BIGINT Ward DOUBLE Community Area BIGINT\n FBI Code VARCHAR X Coordinate DOUBLE Y Coordinate DOUBLE\n Year BIGINT Updated On VARCHAR Latitude DOUBLE Longitude DOUBLE\n Location VARCHAR'

In [11]:
duckdb.sql("SELECT count(*) FROM chicago_crime;")

┌──────────────┐
│ count_star() │
│    int64     │
├──────────────┤
│       688056 │
└──────────────┘

In [12]:
duckdb.sql("SELECT count(*) FROM chicago_crime where year='2021';")

┌──────────────┐
│ count_star() │
│    int64     │
├──────────────┤
│       208976 │
└──────────────┘

In [14]:
duckdb.sql("SELECT count(*) as Arrest_Total FROM chicago_crime where Arrest=True;")

┌──────────────┐
│ Arrest_Total │
│    int64     │
├──────────────┤
│        82297 │
└──────────────┘

In [98]:
def create_message(table_name, query):
    class message:
        def __init__(message, system, user, column_names, column_attr):
            message.system = system
            message.user = user
            message.column_names = column_names
            message.column_attr = column_attr
            
    system_template = """
    Given the following SQL table, your job is to write SQL queries given a user's request. \n
    create table {} ({}) \n
    
    Note : Only SQL Query should be returned, no need to explain the query. \n
    Also there is no requiremnet of any desclaimer or alert or instructions before or after the query.
    """
    
    user_template = """
    Write a SQL query that returns - {} 
    """
    
    tbl_describe = duckdb.sql("DESCRIBE SELECT * FROM " + table_name +  ";")
    col_attr = tbl_describe.to_df()[['column_name','column_type']]
    col_attr["column_joint"] = col_attr["column_name"] + " " +  col_attr["column_type"]
    col_names = re.sub(r"[',\[\]]", "", str(col_attr["column_joint"].values))
    
    system = system_template.format(table_name, col_names)
    user = user_template.format(query)

    m = message(system = system, user = user, 
                column_names = col_attr["column_name"], 
                column_attr = col_attr["column_type"])
    return m

In [47]:
query = "How many cases ended up with arrest?"
msg = create_message(table_name = "chicago_crime", query = query)

In [48]:
print(msg.system)


    Given the following SQL table, your job is to write SQL queries given a user's request. 

    create table chicago_crime (ID BIGINT Case Number VARCHAR Date VARCHAR Block VARCHAR
 IUCR VARCHAR Primary Type VARCHAR Description VARCHAR
 Location Description VARCHAR Arrest BOOLEAN Domestic BOOLEAN
 Beat BIGINT District BIGINT Ward DOUBLE Community Area BIGINT
 FBI Code VARCHAR X Coordinate DOUBLE Y Coordinate DOUBLE
 Year BIGINT Updated On VARCHAR Latitude DOUBLE Longitude DOUBLE
 Location VARCHAR) 

    


In [49]:
print(msg.user)


    Write a SQL query that returns - How many cases ended up with arrest? 
    


In [50]:
print(msg.column_names)

0                       ID
1              Case Number
2                     Date
3                    Block
4                     IUCR
5             Primary Type
6              Description
7     Location Description
8                   Arrest
9                 Domestic
10                    Beat
11                District
12                    Ward
13          Community Area
14                FBI Code
15            X Coordinate
16            Y Coordinate
17                    Year
18              Updated On
19                Latitude
20               Longitude
21                Location
Name: column_name, dtype: object


In [51]:
print(msg.column_attr)

0      BIGINT
1     VARCHAR
2     VARCHAR
3     VARCHAR
4     VARCHAR
5     VARCHAR
6     VARCHAR
7     VARCHAR
8     BOOLEAN
9     BOOLEAN
10     BIGINT
11     BIGINT
12     DOUBLE
13     BIGINT
14    VARCHAR
15     DOUBLE
16     DOUBLE
17     BIGINT
18    VARCHAR
19     DOUBLE
20     DOUBLE
21    VARCHAR
Name: column_type, dtype: object


In [63]:
openai.api_key = key

for names in openai.models.list():
    print(names)

Model(id='text-search-babbage-doc-001', created=1651172509, object='model', owned_by='openai-dev')
Model(id='curie-search-query', created=1651172509, object='model', owned_by='openai-dev')
Model(id='text-davinci-003', created=1669599635, object='model', owned_by='openai-internal')
Model(id='text-search-babbage-query-001', created=1651172509, object='model', owned_by='openai-dev')
Model(id='babbage', created=1649358449, object='model', owned_by='openai')
Model(id='babbage-search-query', created=1651172509, object='model', owned_by='openai-dev')
Model(id='text-babbage-001', created=1649364043, object='model', owned_by='openai')
Model(id='text-similarity-davinci-001', created=1651172505, object='model', owned_by='openai-dev')
Model(id='davinci-similarity', created=1651172509, object='model', owned_by='openai-dev')
Model(id='code-davinci-edit-001', created=1649880484, object='model', owned_by='openai')
Model(id='curie-similarity', created=1651172510, object='model', owned_by='openai-dev')


In [99]:
query = "How many cases ended up with arrest?"
prompt = create_message(table_name = "chicago_crime", query = query)


client = openai.OpenAI()

# to transform the above prompt into the structure of the ChatCompletion.create function messages argument:
message = [
     {
         "role": "system",
         "content": prompt.system
     },
     {
         "role": "user",
         "content": prompt.user
     }]
    
# now send the prompt (message object) to the API using the ChatCompletion.create function:
response = client.chat.completions.create(
        model = "gpt-3.5-turbo",
        messages = message,
        temperature = 0,
        max_tokens = 256)

In [100]:
print(response)

ChatCompletion(id='chatcmpl-8UqyA7NmFKWCK6HM6kC6gXH9M5vUk', choices=[Choice(finish_reason='stop', index=0, message=ChatCompletionMessage(content='SELECT COUNT(*) FROM chicago_crime WHERE Arrest = true;', role='assistant', function_call=None, tool_calls=None))], created=1702364146, model='gpt-3.5-turbo-0613', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=12, prompt_tokens=168, total_tokens=180))


In [101]:
sql_query = response.choices[0].message.content
print(sql_query)

SELECT COUNT(*) FROM chicago_crime WHERE Arrest = true;


In [103]:
# queyr valudation 
duckdb.sql(sql_query).show()

┌──────────────┐
│ count_star() │
│    int64     │
├──────────────┤
│        82297 │
└──────────────┘



In [104]:
# helper function to add quotes to the variables in the query if the returned query does not have one:
def add_quotes(query, col_names):
    for i in col_names:
        if i in query:
            l = query.find(i)
            if query[l-1] != "'" and query[l-1] != '"': 
                query = str(query).replace(i, '"' + i + '"') 
    return(query)

In [105]:
add_quotes(query = sql_query, col_names = prompt.column_names)

'SELECT COUNT(*) FROM chicago_crime WHERE "Arrest" = true;'

In [106]:
def lang2sql(api_key, table_name, query, model = "gpt-3.5-turbo", temperature = 0, 
             max_tokens = 256, frequency_penalty = 0,presence_penalty= 0):
    class response:
        def __init__(output, message, response, sql):
            output.message = message
            output.response = response
            output.sql = sql

    openai.api_key = api_key

    m = create_message(table_name = table_name, query = query)

    message = [
    {
      "role": "system",
      "content": m.system
    },
    {
      "role": "user",
      "content": m.user
    }
    ]
    
    openai_response = client.chat.completions.create(
        model = model,
        messages = message,
        temperature = temperature,
        max_tokens = max_tokens,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty
    )
    
    sql_query = add_quotes(query = openai_response.choices[0].message.content, col_names = m.column_names)

    output = response(message = m, response = openai_response, sql = sql_query)

    return output

In [108]:
query = "How many cases ended up with arrest?"
response = lang2sql(api_key = key, table_name = "chicago_crime", query = query)

In [110]:
print(response.sql)

SELECT COUNT(*) FROM chicago_crime WHERE "Arrest" = true;


In [112]:
query = "How many cases ended up with arrest during 2022"
response = lang2sql(api_key = key, table_name = "chicago_crime", query = query)
print(response.sql)

SELECT COUNT(*) FROM chicago_crime WHERE "Arrest" = TRUE AND "Year" = 2022;


In [115]:
duckdb.sql(response.sql)

┌──────────────┐
│ count_star() │
│    int64     │
├──────────────┤
│        27845 │
└──────────────┘

In [117]:
query = "Summarize the cases by primary type"
response = lang2sql(api_key = key, table_name = "chicago_crime", query = query)

print(response.sql)

SELECT "Primary Type", COUNT(*) AS TotalCases
FROM chicago_crime
GROUP BY "Primary Type";


In [118]:
duckdb.sql(response.sql)

┌───────────────────────────────────┬────────────┐
│           Primary Type            │ TotalCases │
│              varchar              │   int64    │
├───────────────────────────────────┼────────────┤
│ BURGLARY                          │      21113 │
│ NARCOTICS                         │      14753 │
│ CRIMINAL DAMAGE                   │      80072 │
│ INTIMIDATION                      │        529 │
│ CRIMINAL TRESPASS                 │      11973 │
│ OFFENSE INVOLVING CHILDREN        │       5396 │
│ STALKING                          │       1287 │
│ PUBLIC PEACE VIOLATION            │       2119 │
│ PROSTITUTION                      │        586 │
│ ARSON                             │       1415 │
│   ·                               │         ·  │
│   ·                               │         ·  │
│   ·                               │         ·  │
│ ROBBERY                           │      27021 │
│ WEAPONS VIOLATION                 │      25848 │
│ OBSCENITY                    

In [120]:
query = "How many cases is the type of robbery?"
response = lang2sql(api_key = key, table_name = "chicago_crime", query = query)

print(response.sql)

SELECT COUNT(*) FROM chicago_crime WHERE "Primary Type" = 'ROBBERY';


In [121]:
duckdb.sql(response.sql)

┌──────────────┐
│ count_star() │
│    int64     │
├──────────────┤
│        27021 │
└──────────────┘