In [1]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.utilities import SQLDatabase
from transformers import AutoModelForCausalLM, AutoTokenizer
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from pprint import pprint
import torch
import re

In [2]:
template = """
Based on the table schema below, write an SQL query that would answer the user's question:
Schema: {schema}
Question: {question}
SQL Query:
"""

In [3]:
prompt = ChatPromptTemplate.from_template(template)

db_path = "./data/Chinook.db"
# db_path = "./data/airline_data_new.db"
db = SQLDatabase.from_uri(f"sqlite:///{db_path}")    

In [4]:
def get_create_table_statements(schema_info):
    create_table_statements = re.findall(r'CREATE TABLE.*?\);', schema_info, re.DOTALL)
    formatted_schema = "\n\n".join(create_table_statements)
    return formatted_schema

In [5]:
def get_schema(_):
    schema = db.get_table_info()
    schema_cleaned = re.sub(r'/\*.*?\*/', '', schema, flags=re.DOTALL)
    schema_cleaned = schema_cleaned.strip()
    schema_cleaned = re.sub(r'\n\s*\n+', '\n\n', schema_cleaned)

    return schema_cleaned

In [6]:
# model_name = "juierror/text-to-sql-with-table-schema"
model_name = "PipableAI/pip-sql-1.3b"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
"""
def query_llm(prompt_value, max_new_tokens=200):
    prompt_str = str(prompt_value)
    inputs = tokenizer(prompt_str, return_tensors="pt")

    outputs = model.generate(
        inputs.input_ids,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=1
    )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    sql_query = generated_text.split("SQL Query:")[1].strip() if "SQL Query:" in generated_text else generated_text
    return sql_query
"""

def query_llm(prompt_value, max_new_tokens=200):
    prompt_str = str(prompt_value)
    print("Prompt to model:\n", prompt_str)  # Print the prompt to verify it
    inputs = tokenizer(prompt_str, return_tensors="pt", max_length=1024, truncation=True)

    outputs = model.generate(
        inputs.input_ids,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=1
    )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("Generated text:\n", generated_text)  # Print the generated text to verify it

    # Ensure only the SQL query is extracted
    sql_query_match = re.search(r'SQL Query:\s*(.*?)(\n|$)', generated_text)
    sql_query = sql_query_match.group(1).strip() if sql_query_match else generated_text.strip()
    
    # Further clean the SQL query from any extra content
    sql_query = sql_query.split("Answer:")[0].strip()
    return sql_query

In [8]:
sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | query_llm
    | StrOutputParser()
)

In [9]:
result = sql_chain.invoke({"question": "How many artists are there?"})
print(result)

Prompt to model:
 messages=[HumanMessage(content='\nBased on the table schema below, write an SQL query that would answer the user\'s question:\nSchema: CREATE TABLE "Album" (\n\t"AlbumId" INTEGER NOT NULL, \n\t"Title" NVARCHAR(160) NOT NULL, \n\t"ArtistId" INTEGER NOT NULL, \n\tPRIMARY KEY ("AlbumId"), \n\tFOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")\n)\n\nCREATE TABLE "Artist" (\n\t"ArtistId" INTEGER NOT NULL, \n\t"Name" NVARCHAR(120), \n\tPRIMARY KEY ("ArtistId")\n)\n\nCREATE TABLE "Customer" (\n\t"CustomerId" INTEGER NOT NULL, \n\t"FirstName" NVARCHAR(40) NOT NULL, \n\t"LastName" NVARCHAR(20) NOT NULL, \n\t"Company" NVARCHAR(80), \n\t"Address" NVARCHAR(70), \n\t"City" NVARCHAR(40), \n\t"State" NVARCHAR(40), \n\t"Country" NVARCHAR(40), \n\t"PostalCode" NVARCHAR(10), \n\t"Phone" NVARCHAR(24), \n\t"Fax" NVARCHAR(24), \n\t"Email" NVARCHAR(60) NOT NULL, \n\t"SupportRepId" INTEGER, \n\tPRIMARY KEY ("CustomerId"), \n\tFOREIGN KEY("SupportRepId") REFERENCES "Employee" ("Employe



# Remove the comments using a regular expression
schema_cleaned = re.sub(r'/\*.*?\*/', '', schema, flags=re.DOTALL)

schema_cleaned = schema_cleaned.strip()
schema_cleaned = re.sub(r'\n\s*\n+', '\n\n', schema_cleaned)
print(schema_cleaned)

In [10]:
print(result)

messages=[HumanMessage(content='\nBased on the table schema below, write an SQL query that would answer the user\'s question:\nSchema: CREATE TABLE "Album" (\n\t"AlbumId" INTEGER NOT NULL, \n\t"Title" NVARCHAR(160) NOT NULL, \n\t"ArtistId" INTEGER NOT NULL, \n\tPRIMARY KEY ("AlbumId"), \n\tFOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")\n)\n\nCREATE TABLE "Artist" (\n\t"ArtistId" INTEGER NOT NULL, \n\t"Name" NVARCHAR(120), \n\tPRIMARY KEY ("ArtistId")\n)\n\nCREATE TABLE "Customer" (\n\t"CustomerId" INTEGER NOT NULL, \n\t"FirstName" NVARCHAR(40) NOT NULL, \n\t"LastName" NVARCHAR(20) NOT NULL, \n\t"Company" NVARCHAR(80), \n\t"Address" NVARCHAR(70), \n\t"City" NVARCHAR(40), \n\t"State" NVARCHAR(40), \n\t"Country" NVARCHAR(40), \n\t"PostalCode" NVARCHAR(10), \n\t"Phone" NVARCHAR(24), \n\t"Fax" NVARCHAR(24), \n\t"Email" NVARCHAR(60) NOT NULL, \n\t"SupportRepId" INTEGER, \n\tPRIMARY KEY ("CustomerId"), \n\tFOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")\n)\n\nCREATE

In [11]:
print(result.messages)

AttributeError: 'str' object has no attribute 'messages'