In [2]:
import re
import json
import requests
import langchain
import langchain_community 
from langchain_community.utilities import SQLDatabase

In [3]:
# Open the Chinook SQLite file through the SQLDatabase wrapper imported above.
db_path = "./data/Chinook_Sqlite.sqlite"
db = SQLDatabase.from_uri("sqlite:///" + db_path)


In [4]:
def get_schema(_):
    """Return the table info of the module-level ``db`` with comment blocks removed.

    Args:
        _: Ignored. The parameter only exists so the function can be handed to
            callers that pass one positional argument (it is used as the
            ``schema=`` callback of ``RunnablePassthrough.assign`` in this notebook).

    Returns:
        str: The text from ``db.get_table_info()`` with every ``/* ... */`` block
        removed, leading/trailing whitespace stripped, and each run of blank
        lines collapsed to a single blank line.
    """
    schema = db.get_table_info()

    # Drop the /* ... */ blocks; DOTALL lets one block span several lines.
    without_comments = re.sub(r'/\*.*?\*/', '', schema, flags=re.DOTALL)

    # Removing the blocks leaves runs of empty lines behind; collapse each run
    # to exactly one blank line after trimming the ends.
    return re.sub(r'\n\s*\n+', '\n\n', without_comments.strip())

In [5]:
# get_schema ignores its argument, so pass None explicitly; passing IPython's `_`
# (the last cell output) would tie this cell to the kernel's execution history.
schema = get_schema(None)
print(schema)

CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)

CREATE TABLE "Employee" (
	"EmployeeId" INTEGER NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"FirstName" NVARCHAR(20) NOT NULL, 
	"Title" NVARCHAR(30), 
	"ReportsTo" INTEGER, 
	"BirthDate" DATE

In [9]:
# Stand-alone trial prompt with a small example schema and question written inline.
# A backslash joins a line with the next one only when it is the very last character
# on the line; a backslash followed by a space is an invalid escape that stays in the
# text, so no line below may end in "\ ".
prompt = """
            You are an expert SQL query generator. You will be provided with a database schema and user instructions.\
            Your task is to generate accurate and efficient SQL queries based on the given schema and user requirements. \
            Follow these steps:\
     \
     ***instructions*** \
     1.Understand the Schema: \
        Thoroughly examine the provided database schema. \
        Note the tables, columns, data types, and relationships between tables (e.g., primary keys, foreign keys).\
     2.Interpret User Instructions: \
        Carefully read and understand the user's requirements for the SQL query. \
        Identify the tables and columns involved, the conditions for filtering data, and the desired output.\
    3.Generate SQL Queries:\
        Write SQL queries that accurately reflect the user's requirements.\
        Ensure the queries are syntactically correct and optimized for performance.\
        Return the Query:\

    Provide the generated SQL query as the result. \
     Schema :\
         Table: Employees\
    - EmployeeID (INT, Primary Key)\
    - FirstName (VARCHAR)\
    - LastName (VARCHAR)\
    - DepartmentID (INT, Foreign Key)\
    \
    Table: Departments\
    - DepartmentID (INT, Primary Key)\
    - DepartmentName (VARCHAR)\
    \
    Table: Salaries\
    - EmployeeID (INT, Primary Key, Foreign Key)\
    - Salary (DECIMAL)\
    - EffectiveDate (DATE)\

    User Instruction:\
    'create a query to find the names of employees in the 'Sales' department who have a salary greater than $50,000.'\

    Just provide SQL query in response and explaination what query does do not add anything extra.
    """
#      Example Output\
#         SELECT Employees.FirstName, Employees.LastName\
#         FROM Employees\
#         JOIN Departments ON Employees.DepartmentID = Departments.DepartmentID\
#         JOIN Salaries ON Employees.EmployeeID = Salaries.EmployeeID\
#         WHERE Departments.DepartmentName = 'Sales'\
#         AND Salaries.Salary > 50000;\
#      '''
# """

In [15]:
# Post the current `prompt` to the local generate endpoint and print the streamed reply.
url = 'http://localhost:11434/api/generate'

data = {
        "model": 'llama3',
        "prompt": prompt
    }

full_response = []

headers = {'Content-Type': 'application/json'}
response = requests.post(url, data=json.dumps(data), headers=headers, stream=True)
try:
    # Stop on an HTTP error status here rather than with a KeyError on a body
    # that has no 'response' field.
    response.raise_for_status()
    # The body is read as one JSON object per non-empty line; each object's
    # 'response' field holds the next piece of generated text.
    for line in response.iter_lines():
        if line:
            decoded_line = json.loads(line.decode('utf-8'))
            full_response.append(decoded_line['response'])
finally:
    response.close()

print(''.join(full_response))

```
SELECT e.FirstName, e.LastName
FROM Employees e
JOIN Departments d ON e.DepartmentID = d.DepartmentID
WHERE d.DepartmentName = 'Sales' AND EXISTS (
  SELECT 1
  FROM Salaries s
  WHERE s.EmployeeID = e.EmployeeID AND s.Salary > 50000.00
);
```


In [16]:
# Prompt template with {schema} and {user_question} placeholders; asks for the model's
# understanding, the SQL query and an explanation.
# A backslash joins a line with the next one only when it is the very last character
# on the line; a backslash followed by a space is an invalid escape that stays in the
# text, so no line below may end in "\ ".
prompt = """
            You are an expert SQL query generator. You will be provided with a database schema and user instructions.\
            Your task is to generate accurate and efficient SQL queries based on the given schema and user requirements. \
            Follow these steps:\
     \
     ***instructions*** \
     1.Understand the Schema: \
        Thoroughly examine the provided database schema. \
        Note the tables, columns, data types, and relationships between tables (e.g., primary keys, foreign keys).\
     2.Interpret User Instructions: \
        Carefully read and understand the user's requirements for the SQL query. \
        Identify the tables and columns involved, the conditions for filtering data, and the desired output.\
    3.Generate SQL Queries:\
        Write SQL queries that accurately reflect the user's requirements.\
        Ensure the queries are syntactically correct and optimized for performance.\
        Return the Query:\

    Provide the generated SQL query as the result. \
    {schema}

    User Instruction:\
    {user_question}

    Tell me what you have understood and then \
    Just provide SQL query and \
    explaination what query does in the response\
    do not add anything extra.
    """

In [17]:
def generate_answer(schema,question,model="llama3"):
    """Fill the module-level ``prompt`` and return the text generated for it.

    The filled prompt is posted as JSON to the module-level ``url`` with
    streaming enabled; the reply is read as one JSON object per line and the
    ``'response'`` fields are concatenated in order.

    Args:
        schema: Text substituted for ``{schema}`` in ``prompt``.
        question: Text substituted for ``{user_question}`` in ``prompt``.
        model: Model name sent in the request payload.

    Returns:
        str: The concatenated ``'response'`` pieces.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        RuntimeError: If a streamed object carries an ``'error'`` field.
        KeyError: If a streamed object has neither ``'response'`` nor ``'error'``.
    """
    payload = {
        "model": model,
        "prompt": prompt.format(schema=schema,user_question=question)
    }
    headers = {'Content-Type': 'application/json'}

    pieces = []
    # The context manager closes the streamed connection on every exit path.
    with requests.post(url, data=json.dumps(payload), headers=headers, stream=True) as response:
        # Report HTTP failures as such instead of as a KeyError further down.
        response.raise_for_status()
        for line in response.iter_lines():
            if not line:
                continue  # skip keep-alive / empty lines
            chunk = json.loads(line.decode('utf-8'))
            if 'error' in chunk:
                raise RuntimeError(f"generation failed: {chunk['error']}")
            pieces.append(chunk['response'])

    return ''.join(pieces)

In [19]:
# Try generate_answer on a single-table lookup; the returned string is the cell's output.
generate_answer(schema=schema,question="What is the title of the album with AlbumId 67")

"I understand that the user wants to know the title of the album with AlbumId 67. \n\nSQL Query:\n```sql\nSELECT Title \nFROM Album \nWHERE AlbumId = 67;\n```\nExplanation: This SQL query selects the 'Title' column from the 'Album' table where the 'AlbumId' is equal to 67. The WHERE clause filters the results to only include the album with AlbumId 67, and then the SELECT statement retrieves its title."

In [21]:
# A hard-coded model reply, printed so its "\n" escapes show up as real line breaks.
# Note: the name `text` is rebound later by `from sqlalchemy import ... text`.
text = "I understand that the user wants to know the title of the album with AlbumId 67. \n\nSQL Query:\n```sql\nSELECT Title \nFROM Album \nWHERE AlbumId = 67;\n```\nExplanation: This SQL query selects the 'Title' column from the 'Album' table where the 'AlbumId' is equal to 67. The WHERE clause filters the results to only include the album with AlbumId 67, and then the SELECT statement retrieves its title."
print(text)

I understand that the user wants to know the title of the album with AlbumId 67. 

SQL Query:
```sql
SELECT Title 
FROM Album 
WHERE AlbumId = 67;
```
Explanation: This SQL query selects the 'Title' column from the 'Album' table where the 'AlbumId' is equal to 67. The WHERE clause filters the results to only include the album with AlbumId 67, and then the SELECT statement retrieves its title.


In [36]:
# Prompt template with {schema} and {user_question} placeholders; asks for the SQL query only.
# A backslash joins a line with the next one only when it is the very last character
# on the line; a backslash followed by a space is an invalid escape that stays in the
# text, so no line below may end in "\ ".
prompt = """
            You are an expert SQL query generator. You will be provided with a database schema and user instructions.\
            Your task is to generate accurate and efficient SQL queries based on the given schema and user requirements. \
            Follow these steps:\
     \
     ***instructions*** \
     1.Understand the Schema: \
        Thoroughly examine the provided database schema. \
        Note the tables, columns, data types, and relationships between tables (e.g., primary keys, foreign keys).\
     2.Interpret User Instructions: \
        Carefully read and understand the user's requirements for the SQL query. \
        Identify the tables and columns involved, the conditions for filtering data, and the desired output.\
    3.Generate SQL Queries:\
        Write SQL queries that accurately reflect the user's requirements.\
        Ensure the queries are syntactically correct and optimized for performance.\
        Return the Query:\

    Provide the generated SQL query as the result. \
    {schema}

    User Instruction:\
    {user_question}

    ***strict instructions***
    Just provide SQL query and \
    Do not add anything extra.
    """

In [13]:
def generate_answer(schema,question,model="llama3"):
    """Fill the module-level ``prompt`` and return the text generated for it.

    The filled prompt is posted as JSON to the module-level ``url`` with
    streaming enabled; the reply is read as one JSON object per line and the
    ``'response'`` fields are concatenated in order.

    Args:
        schema: Text substituted for ``{schema}`` in ``prompt``.
        question: Text substituted for ``{user_question}`` in ``prompt``.
        model: Model name sent in the request payload.

    Returns:
        str: The concatenated ``'response'`` pieces.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        RuntimeError: If a streamed object carries an ``'error'`` field.
        KeyError: If a streamed object has neither ``'response'`` nor ``'error'``.
    """
    payload = {
        "model": model,
        "prompt": prompt.format(schema=schema,user_question=question)
    }
    headers = {'Content-Type': 'application/json'}

    pieces = []
    # The context manager closes the streamed connection on every exit path.
    with requests.post(url, data=json.dumps(payload), headers=headers, stream=True) as response:
        # Report HTTP failures as such instead of as a KeyError further down.
        response.raise_for_status()
        for line in response.iter_lines():
            if not line:
                continue  # skip keep-alive / empty lines
            chunk = json.loads(line.decode('utf-8'))
            if 'error' in chunk:
                raise RuntimeError(f"generation failed: {chunk['error']}")
            pieces.append(chunk['response'])

    return ''.join(pieces)

In [24]:
# Same question again; print() shows the reply with its line breaks rendered.
print(generate_answer(schema=schema,question="What is the title of the album with AlbumId 67"))

SELECT Title
FROM Album
WHERE AlbumId = 67;


In [38]:
from langchain.llms import Ollama
from langchain.schema.runnable import RunnablePassthrough
from langchain.prompts import PromptTemplate

# Wrap the current module-level `prompt` string; it is filled from the
# 'schema' and 'user_question' variables.
prompt_temp = PromptTemplate(template = prompt,input_variables=['schema','user_question'])

# LLM client for the 'llama3' model served at localhost:11434.
llm = Ollama(
    base_url='http://localhost:11434',
    model='llama3',
)

# Pipeline: add a 'schema' entry produced by get_schema to the input,
# render the prompt from it, then pass the rendered prompt to the LLM.
sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | prompt_temp
    | llm
)

In [42]:
# Only 'user_question' is supplied here; 'schema' is added inside sql_chain via get_schema.
user_question = 'how many albums are there in the database?'
sql_chain.invoke({"user_question": user_question})

# 'SELECT COUNT(*) AS TotalAlbums\nFROM Album;'

'SELECT COUNT(*) FROM Album;'

In [57]:
from sqlalchemy import create_engine, inspect, MetaData, text
import time

def run_query(sql_query, db_uri):
    """Execute ``sql_query`` against the database at ``db_uri`` and return its rows.

    The query and the fetched rows are also printed.

    Args:
        sql_query: SQL text to execute. It is run exactly as given, so callers
            passing model-generated SQL should treat it as untrusted input.
        db_uri: Database URL handed to ``create_engine``.

    Returns:
        list: All rows from ``fetchall()`` on the executed statement.
    """
    print(f"SQL Query:{sql_query}")
    engine = create_engine(db_uri)
    try:
        with engine.connect() as conn:
            rows = conn.execute(text(sql_query)).fetchall()
    finally:
        # A new engine is created on every call; release its connections
        # instead of leaving them to garbage collection.
        engine.dispose()
    print(f"result:{rows}")
    # Return the fetched rows (the previous code returned an undefined name).
    return rows

In [58]:
uri = "sqlite:///./data/Chinook_Sqlite.sqlite"

# Answer prompt. It is defined before full_chain because the chain references p_t
# when it is built, and it includes the executed query and its rows so the model
# can answer from the database result rather than from the question alone.
prompt_t = """
            answer the user question.
            {user_question}

            SQL query that was run:
            {query}

            Rows returned by the query:
            {response}
"""

p_t = PromptTemplate(template=prompt_t,input_variables=["user_question", "query", "response"])

# Pipeline: generate SQL with sql_chain ('query'), run it ('response'),
# render the answer prompt from those values, then ask the LLM for the final answer.
full_chain = (
    RunnablePassthrough.assign(query=sql_chain).assign(
        schema=get_schema,
        response=lambda vars: run_query(vars["query"],uri)
    )
    | p_t
    | llm
)

In [None]:
# Run the full pipeline for one question; run_query prints the SQL and rows along the way.
user_question = 'how many albums are there in the database?'
full_chain.invoke({"user_question": user_question})

# 'There are 347 albums in the database.'

SQL Query:SELECT COUNT(*) AS Total_Albums
FROM "Album";
result:[(347,)]
0


In [None]:
import os
from getpass import getpass
from urllib.parse import quote_plus

from langchain_experimental.sql import SQLDatabaseChain

question = "how many albums are there in the database?"

# Credentials come from the environment (or an interactive prompt) so that no
# password is stored in the notebook. quote_plus keeps special characters in the
# user name or password from breaking the URL.
mysql_user = os.environ.get("MYSQL_USER", "root")
mysql_password = os.environ.get("MYSQL_PASSWORD") or getpass("MySQL password: ")
mysql_uri = (
    f'mysql+mysqlconnector://{quote_plus(mysql_user)}:{quote_plus(mysql_password)}'
    '@localhost:3306/chinook'
)

db1 = SQLDatabase.from_uri(mysql_uri)

# NOTE(review): this rebinds the module-level `prompt` that generate_answer reads at
# call time, and neither `prompt` nor `prompt_template` is passed to the chain below.
prompt = """
    You are an expert SQL query generator.

    Just provide SQL query and \
    Do not add anything extra.
    """

prompt_template = PromptTemplate(template=prompt,input_variables=[])

# Create the SQL database chain
# NOTE(review): the chain is built on `db` (the SQLite database), not on the MySQL
# connection `db1` created above — confirm which one is intended.
db_chain = SQLDatabaseChain.from_llm(
    llm=llm,
    database=db,
    verbose = True
)

db_chain.run(question)