In [1]:
import re
import langchain
import langchain_community 
from langchain_community.utilities import SQLDatabase

In [2]:
#Creating SQLite database cursor 
db_path = "./data/Chinook_Sqlite.sqlite"
db = SQLDatabase.from_uri(f"sqlite:///{db_path}")   

In [3]:
mysql_uri = 'mysql+mysqlconnector://root:keval@localhost:3306/chinook'
db1 = SQLDatabase.from_uri(mysql_uri)

In [4]:
def get_schema(_):
    schema = db.get_table_info()
    
    #to remove charachters between /* */ and new line characters.
    schema_cleaned = re.sub(r'/\*.*?\*/', '', schema, flags=re.DOTALL)
    
    #to remove leading and trailing spaces
    schema_cleaned = schema_cleaned.strip()
    
    #to remove extra blank spaces
    schema_cleaned = re.sub(r'\n\s*\n+', '\n\n', schema_cleaned)
    
    splited = schema_cleaned.split('\nCREATE')

    for i in range(1,len(splited)):
        splited[i] = "CREATE" + splited[i]    

    final_schema = "\n".join(splited)

    return final_schema

In [8]:
#SQLite
print(get_schema(db))

CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)

CREATE TABLE "Employee" (
	"EmployeeId" INTEGER NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"FirstName" NVARCHAR(20) NOT NULL, 
	"Title" NVARCHAR(30), 
	"ReportsTo" INTEGER, 
	"BirthDate" DATE

In [6]:
#MySQL
print(get_schema(db1))

CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)

CREATE TABLE "Employee" (
	"EmployeeId" INTEGER NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"FirstName" NVARCHAR(20) NOT NULL, 
	"Title" NVARCHAR(30), 
	"ReportsTo" INTEGER, 
	"BirthDate" DATE

In [31]:
template = """
Based on the table schema below, write a SQL query that would answer the user's question:

{schema}

Question: 
{question}

SQL Query:

"""

prompt = ChatPromptTemplate.from_template(template)

In [43]:
llm = Ollama(
    base_url='http://localhost:11434',
    model='llama3',
)

sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | llm.bind(stop=["\nSQLResult:"])
)

In [44]:
user_question = 'how many albums are there in the database?'
sql_chain.invoke({"question": user_question})

# 'SELECT COUNT(*) AS TotalAlbums\nFROM Album;'

To answer this question, you can use a SQL query that selects the count of rows from the `Album` table. Here is the query:

```sql
SELECT COUNT(*) AS Total_Albums
FROM Album;
```

This query will return a single row with one column named `Total_Albums`, which contains the total number of albums in the database.

'To answer this question, you can use a SQL query that selects the count of rows from the `Album` table. Here is the query:\n\n```sql\nSELECT COUNT(*) AS Total_Albums\nFROM Album;\n```\n\nThis query will return a single row with one column named `Total_Albums`, which contains the total number of albums in the database.'

In [80]:
template = """You are an expert SQL query generator. You will be provided with a database schema and user instructions.\ 
            Your task is to generate accurate and efficient SQL queries based on the given schema and user requirements. \
            Follow these steps:\
     \
     ***instructions*** \
     1.Understand the Schema: \
        Thoroughly examine the provided database schema. \
        Note the tables, columns, data types, and relationships between tables (e.g., primary keys, foreign keys).\
     2.Interpret User Instructions: \ 
        Carefully read and understand the user's requirements for the SQL query. \
        Identify the tables and columns involved, the conditions for filtering data, and the desired output.\
    3.Generate SQL Queries:\
        Write SQL queries that accurately reflect the user's requirements.\
        Ensure the queries are syntactically correct and optimized for performance.\
        Return the Query:\   

    Provide the generated SQL query as the result. \
    {schema}
    
    User Instruction:\
    {question}
    
    ```strict instruction```
    Just provide SQL query and \
    do not add anything extra.
"""

prompt_response = PromptTemplate.from_template(template)

In [81]:
from sqlalchemy import create_engine, inspect, MetaData, text

def run_query(sql_query, db_uri):
    print(sql_query)
    engine = create_engine(db_uri)
    with engine.connect() as conn:
        result = conn.execute(text(sql_query))
        return [dict(row) for row in result]

In [84]:
full_chain = (
    RunnablePassthrough.assign(query=sql_chain).assign(
        schema=get_schema,
        response=lambda vars: run_query(vars['query'],mysql_uri)
    )
    | prompt_response
    | llm
)

In [None]:
def query_llm(prompt_value, max_new_tokens=200):
    prompt_str = str(prompt_value)
#     print("Prompt to model:\n", prompt_str)  # Print the prompt to verify it
    inputs = tokenizer(prompt_str, return_tensors="pt", max_length=1024, truncation=True)

    outputs = model.generate(
        inputs.input_ids,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=1
    )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
#     print("Generated text:\n", generated_text)  # Print the generated text to verify it

    # Ensure only the SQL query is extracted
    sql_query_match = re.search(r'SQL Query:\s*(.*?)(\n|$)', generated_text)
    sql_query = sql_query_match.group(1).strip() if sql_query_match else generated_text.strip()
    
    # Further clean the SQL query from any extra content
    sql_query = sql_query.split("Answer:")[0].strip()
    return sql_query

In [None]:
sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt
    | query_llm
    | StrOutputParser()
)

In [85]:
user_question = 'how many albums are there in the database?'
full_chain.invoke({"question": user_question})

# 'There are 347 albums in the database.'

Here is the SQL query that answers the user's question:

```sql
SELECT COUNT(*) 
FROM Album;
```

This query uses the `COUNT(*)` function to count the number of rows in the `Album` table, which corresponds to the number of albums in the database. The result will be a single row with one column containing the count of albums.SELECT count(*) from album;


TypeError: cannot convert dictionary update sequence element #0 to a sequence