In [70]:
import os
import re
from pprint import pprint as pp

import langchain
import langchain_community
import langchain_core
import torch
import transformers
from langchain.llms import Ollama
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough

In [2]:
# Verbose text-to-SQL prompt with {schema} and {question} placeholders.
# Spelled out one line per literal so the trailing spaces after "Schema:" and
# "Question:" and the surrounding blank lines are visible and survive editors
# that strip trailing whitespace.
prompt_template = (
    "\n"
    "Based on the table schema below, write an SQL query that would answer the user's question:\n"
    "\n"
    "Schema: \n"
    "{schema}\n"
    "\n"
    "Question: \n"
    "{question}\n"
    "\n"
    "SQL Query:\n"
    "\n"
)

In [41]:
# Connect to the SQLite file through a SQLAlchemy-style "sqlite:///" URI.
# Alternative dataset that was tried here: "./data/airline_data_new.db"
db_path = "./data/Chinook_Sqlite.sqlite"
db = SQLDatabase.from_uri("sqlite:///" + db_path)

# Build the chat prompt from the verbose template defined earlier in the notebook.
prompt = ChatPromptTemplate.from_template(prompt_template)

In [44]:
# Connection URI for the MySQL database, read from the environment so that no
# credentials are stored in the notebook. Expected format:
#   mysql+mysqlconnector://<user>:<password>@<host>:<port>/<database>
# A missing variable raises KeyError naming CHINOOK_MYSQL_URI.
mysql_uri = os.environ["CHINOOK_MYSQL_URI"]
db1 = SQLDatabase.from_uri(mysql_uri)

In [42]:
def get_schema(db):
    """Return the database's table info with comment blocks and extra blank lines removed.

    Args:
        db: Object exposing ``get_table_info()`` that returns the schema text as
            a string (in this notebook, a ``SQLDatabase`` instance).

    Returns:
        str: The schema text with every ``/* ... */`` block removed, leading and
        trailing whitespace stripped, and each run of blank or whitespace-only
        lines collapsed to a single blank line.
    """
    schema = db.get_table_info()

    # Remove /* ... */ blocks; DOTALL lets one block span several lines and the
    # non-greedy match stops at the first closing "*/".
    schema_cleaned = re.sub(r'/\*.*?\*/', '', schema, flags=re.DOTALL)
    schema_cleaned = schema_cleaned.strip()

    # Removing the blocks leaves gaps of several newlines; keep exactly one
    # blank line between the remaining statements.
    schema_cleaned = re.sub(r'\n\s*\n+', '\n\n', schema_cleaned)

    return schema_cleaned

In [65]:
# Show the cleaned schema text for `db`, the connection opened from the SQLite file.
print(get_schema(db))

CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)

CREATE TABLE "Employee" (
	"EmployeeId" INTEGER NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"FirstName" NVARCHAR(20) NOT NULL, 
	"Title" NVARCHAR(30), 
	"ReportsTo" INTEGER, 
	"BirthDate" DATE

In [45]:
# Show the cleaned schema text for `db1`, the connection opened from the MySQL URI.
print(get_schema(db1))

CREATE TABLE album (
	`AlbumId` INTEGER NOT NULL, 
	`Title` VARCHAR(160) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci NOT NULL, 
	`ArtistId` INTEGER NOT NULL, 
	PRIMARY KEY (`AlbumId`), 
	CONSTRAINT `FK_AlbumArtistId` FOREIGN KEY(`ArtistId`) REFERENCES artist (`ArtistId`)
)DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB

CREATE TABLE artist (
	`ArtistId` INTEGER NOT NULL, 
	`Name` VARCHAR(120) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci, 
	PRIMARY KEY (`ArtistId`)
)DEFAULT CHARSET=utf8mb4 COLLATE utf8mb4_0900_ai_ci ENGINE=InnoDB

CREATE TABLE customer (
	`CustomerId` INTEGER NOT NULL, 
	`FirstName` VARCHAR(40) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci NOT NULL, 
	`LastName` VARCHAR(20) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci NOT NULL, 
	`Company` VARCHAR(80) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci, 
	`Address` VARCHAR(70) CHARACTER SET utf8mb3 COLLATE utf8mb3_general_ci, 
	`City` VARCHAR(40) CHARACTER SET utf8mb3 COLLATE utf8mb3_genera

In [67]:
# Compact text-to-SQL prompt with {schema} and {question} placeholders.
# The text ends right after "SQL Query:" with no trailing newline.
template = (
    "Based on the table schema below, write a SQL query that would answer the user's question:\n"
    "{schema}\n"
    "\n"
    "Question: {question}\n"
    "SQL Query:"
)

# Rebinds `prompt`, replacing the prompt built earlier from `prompt_template`.
prompt = ChatPromptTemplate.from_template(template)

In [68]:


# Ollama LLM wrapper built with no arguments, so the library's default model
# and endpoint settings apply.
llm = Ollama()

# Text-to-SQL pipeline: add the schema to the input, render the prompt, call
# the LLM, and return its reply as a plain string.
sql_chain = (
    # `assign` calls its function with the chain's input mapping, not with a
    # database object, so passing `get_schema` directly would end up calling
    # `.get_table_info()` on a dict. Ignore that argument and pass the SQLite
    # connection `db` explicitly.
    RunnablePassthrough.assign(schema=lambda _: get_schema(db))
    | prompt
    # Stop generating if the model starts writing a "SQLResult:" section.
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)
