### Testing SQL Querying Using LangChain

In [4]:
import os
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser, JsonKeyOutputFunctionsParser
from langchain.chat_models import ChatOllama
from langchain.llms import Ollama
from langchain.schema.runnable import RunnablePassthrough
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
from langchain.agents import create_sql_agent
from langchain_experimental.sql import SQLDatabaseChain
from langchain.prompts import PromptTemplate
import json
from IPython.display import display, Markdown, Latex

In [4]:
os.environ["OPENAI_API_KEY"] = ""

#### Connect to the DB using LangChain and test sample questions

Note that **sql_chain** takes a **question** and **generates a SQL query**. It does not run the query.

In [51]:
openai_llm = ChatOpenAI(model="gpt-3.5-turbo-16k", temperature=0)

db = SQLDatabase.from_uri("sqlite:///DBs/Chinook.db")

sql_chain = create_sql_query_chain(llm=openai_llm, db=db)

In [72]:
print(sql_chain.__dict__["middle"][0].template)

You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: 

In [12]:
sql_query = sql_chain.invoke(
    {
        "question": "How many employees are in the database?"
    }
)

In [13]:
sql_query

'SELECT COUNT(*) FROM Employee'

We can run the resulting query directly against the database (a hint that the two steps could be piped together).
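A minimal sketch of that piping idea, using only the standard library: a query-generation step and an execution step are composed into one callable. `fake_generate_sql` is a hypothetical stand-in for `sql_chain.invoke` that always returns a fixed query, `make_run` is a hypothetical stand-in for `db.run`, and the in-memory `Employee` table holds three illustrative rows rather than the Chinook data. In LangChain itself, the equivalent would be piping `sql_chain` into a runnable that calls `db.run` (for example `sql_chain | db.run`, which LCEL should coerce into a `RunnableLambda`).

```python
import sqlite3
from typing import Callable


def fake_generate_sql(inputs: dict) -> str:
    # Hypothetical stand-in for sql_chain.invoke: ignores the question
    # and returns a fixed query, like the one GPT-3.5 produced above.
    return "SELECT COUNT(*) FROM Employee"


def make_run(conn: sqlite3.Connection) -> Callable[[str], str]:
    # Hypothetical stand-in for db.run: execute the query and return the rows as a string.
    def run(query: str) -> str:
        return str(conn.execute(query).fetchall())
    return run


def pipe(*steps: Callable) -> Callable:
    # Feed each step's output into the next one, left to right.
    def chained(value):
        for step in steps:
            value = step(value)
        return value
    return chained


conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE Employee ("EmployeeId" INTEGER PRIMARY KEY, "LastName" TEXT)')
conn.executemany(
    "INSERT INTO Employee (LastName) VALUES (?)",
    [("Adams",), ("Edwards",), ("Peacock",)],  # illustrative rows only
)

question_to_answer = pipe(fake_generate_sql, make_run(conn))
print(question_to_answer({"question": "How many employees are in the database?"}))  # [(3,)]
```

Keeping generation and execution as separate steps makes it easy to inspect or validate the SQL before it reaches the database.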

In [16]:
db.run(sql_query)

'[(8,)]'

**Be careful with your expectations of the questions you ask.** As seen below, there are limits to what the chain can answer.

In [31]:
sql_query = sql_chain.invoke(
    {
        "question": "what is my data about?"
    }
)

The question is very broad, and GPT-3.5 did "its best": instead of describing the data, it counted the tables in the database.

In [32]:
sql_query


"SELECT COUNT(*) AS TotalTables FROM sqlite_master WHERE type='table';"

In [33]:
db.run(sql_query)

'[(11,)]'
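The generated query counts the rows of SQLite's built-in `sqlite_master` catalog whose `type` is `'table'`, so `11` is the number of tables, matching the eleven table names the agent lists further down. The sketch below reproduces that query with `sqlite3` on a throwaway in-memory database holding two hypothetical tables, and contrasts it with listing the table names, which says more about what the data contains.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Two hypothetical tables standing in for part of the Chinook schema.
conn.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
conn.execute("CREATE TABLE Album (AlbumId INTEGER PRIMARY KEY, Title TEXT, ArtistId INTEGER)")

# The query GPT-3.5 generated: it only counts the tables.
count = conn.execute(
    "SELECT COUNT(*) AS TotalTables FROM sqlite_master WHERE type='table';"
).fetchone()[0]
print(count)  # 2

# Listing the table names is a better starting point for "what is my data about?".
names = [row[0] for row in conn.execute(
    "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
)]
print(names)  # ['Album', 'Artist']
```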

In [34]:
sql_query = sql_chain.invoke(
    {
        "question": "How many customers are there in the database?"
    }
)

In [36]:
db.run(sql_query)

'[(59,)]'

In [37]:
sql_query = sql_chain.invoke(
    {
        "question": "How many employees born after 1970 are there?"
    }
)

In [38]:
sql_query

"SELECT COUNT(*) FROM Employee WHERE BirthDate > '1970-01-01'"

In [39]:
db.run(sql_query)

'[(3,)]'

In [41]:
sql_query = sql_chain.invoke(
    {
        "question": "List the names of the employees born after 1973"
    }
)

sql_query

'SELECT "FirstName", "LastName" \nFROM "Employee" \nWHERE "BirthDate" > \'1973-01-01\''

In [42]:
db.run(sql_query)

"[('Jane', 'Peacock'), ('Michael', 'Mitchell')]"
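Both queries compare `BirthDate` with a string literal. Chinook stores birth dates as text such as `'1973-08-29 00:00:00'` (see the full listing at the end of this section), and ISO-8601 strings sort chronologically, so the comparison works. However, `> '1973-01-01'` also matches anyone born during 1973: Peacock (1973-08-29) and Mitchell (1973-07-01) were born in 1973, not after it. Likewise, the count of 3 for "born after 1970" includes King, born 1970-05-29. The plain-Python sketch below replays these comparisons on the birth dates from that listing; `born_where` is a hypothetical helper, and reading "after 1973" as `>= '1974-01-01'` is our own assumption about the intended meaning.

```python
# Birth dates copied from the Employee listing printed at the end of this section.
birth_dates = {
    "Peacock": "1973-08-29 00:00:00",
    "Mitchell": "1973-07-01 00:00:00",
    "King": "1970-05-29 00:00:00",
    "Callahan": "1968-01-09 00:00:00",
    "Johnson": "1965-03-03 00:00:00",
    "Adams": "1962-02-18 00:00:00",
    "Edwards": "1958-12-08 00:00:00",
    "Park": "1947-09-19 00:00:00",
}


def born_where(predicate):
    # Return matching last names, using plain lexicographic string comparison,
    # which is how SQLite compares these TEXT values against a string literal.
    return sorted(name for name, date in birth_dates.items() if predicate(date))


print(born_where(lambda d: d > "1970-01-01"))   # generated query: ['King', 'Mitchell', 'Peacock']
print(born_where(lambda d: d >= "1971-01-01"))  # strictly after 1970: ['Mitchell', 'Peacock']
print(born_where(lambda d: d > "1973-01-01"))   # generated query: ['Mitchell', 'Peacock']
print(born_where(lambda d: d >= "1974-01-01"))  # strictly after 1973: []
```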

#### Test with a local LLM (Llama 2 via Ollama) using a custom prompt template

In [25]:
template = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the following tables:

{table_info}.

Question: {input}"""

custom_prompt = ChatPromptTemplate.from_template(template)
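`ChatPromptTemplate.from_template` treats `{dialect}`, `{table_info}` and `{input}` as f-string-style placeholders. As a stdlib-only sketch of what the filled prompt looks like, the snippet below uses `str.format` as a stand-in for LangChain's formatter and a hypothetical, heavily shortened `table_info` string. Note that the default prompt used by `create_sql_query_chain` also expects a `top_k` variable (see its `input_variables` in the output below), which this template does not include.

```python
import string

template = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the following tables:

{table_info}.

Question: {input}"""

# Placeholder names the template expects, in order of appearance.
placeholders = [field for _, field, _, _ in string.Formatter().parse(template) if field]
print(placeholders)  # ['dialect', 'table_info', 'input']

filled = template.format(
    dialect="SQLite",
    # Hypothetical, heavily shortened schema description.
    table_info='CREATE TABLE "Employee" ("EmployeeId" INTEGER NOT NULL, "LastName" NVARCHAR(20) NOT NULL)',
    input="List the names of the employees born after 1973",
)
print(filled)
```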

In [42]:
ollama_llm = "llama2:13b"

llama_llm = ChatOllama(model=ollama_llm, temperature=0)

db = SQLDatabase.from_uri("sqlite:///DBs/Chinook.db")

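# Note: custom_prompt is not passed here (create_sql_query_chain accepts a prompt= argument),
# so this chain still uses the default SQLite prompt shown in the output below
sql_chain = create_sql_query_chain(llm=llama_llm, db=db)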

In [43]:
sql_chain

{
  input: RunnableLambda(...),
  top_k: RunnableLambda(...),
  table_info: RunnableLambda(...)
}
| PromptTemplate(input_variables=['input', 'table_info', 'top_k'], template='You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\nPay attentio

In [44]:
sql_query = sql_chain.invoke({
    "question" : "List the names of the employees born after 1973"
})

sql_query

"SELECT Name\nFROM Employee\nWHERE BirthDate > '1973-01-01'\n\nThis query will select the names of all employees who were born after 1973. The `BirthDate` column is a date data type, and the `>` operator checks if the date is greater than the specified date. In this case, we are comparing the birth dates to the date '1973-01-01', which is the first day of 1973. Any employee with a birth date after this date will have their name selected.\n\nThe result set will contain the names of all employees who were born after 1973, like this:\n\nResult Set:\nName\n---\nJenifer\nMike\nSarah"

In [45]:
print(sql_query)

SELECT Name
FROM Employee
WHERE BirthDate > '1973-01-01'

This query will select the names of all employees who were born after 1973. The `BirthDate` column is a date data type, and the `>` operator checks if the date is greater than the specified date. In this case, we are comparing the birth dates to the date '1973-01-01', which is the first day of 1973. Any employee with a birth date after this date will have their name selected.

The result set will contain the names of all employees who were born after 1973, like this:

Result Set:
Name
---
Jenifer
Mike
Sarah
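Llama 2 returned the SQL followed by an explanation and an invented result set, so `sql_query` cannot be passed to `db.run` as-is. The query is also wrong on its own terms: `Employee` has `FirstName` and `LastName` columns but no `Name` column (see the schema in the agent trace below), and the names Jenifer, Mike and Sarah do not match the two employees GPT-3.5's query returned above. A minimal post-processing sketch, assuming the SQL always comes first and ends at the first blank line; `extract_sql` is a hypothetical helper, not a LangChain API.

```python
def extract_sql(raw: str) -> str:
    """Keep only the leading SQL statement from a chatty LLM response.

    Assumes the statement comes first and is separated from any
    explanation by a blank line. Also drops an echoed 'SQLResult:'
    marker if the model repeated the prompt format.
    """
    text = raw.strip()
    # Cut at the first blank line, where the explanation starts.
    first_block = text.split("\n\n", 1)[0]
    # Drop an echoed "SQLResult:" marker and anything after it.
    first_block = first_block.split("SQLResult:", 1)[0]
    return first_block.strip()


raw_output = (
    "SELECT Name\nFROM Employee\nWHERE BirthDate > '1973-01-01'\n\n"
    "This query will select the names of all employees who were born after 1973."
)
print(extract_sql(raw_output))
```

This only trims the response; the nonexistent `Name` column would still make the query fail when passed to `db.run`.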


#### Test the SQL agent with Ollama

In [54]:
agent_executor = create_sql_agent(
    llm=llama_llm,  # the ChatOllama instance defined above (same model, temperature=0)
    toolkit=SQLDatabaseToolkit(db=db, llm=llama_llm),
    verbose=True,
    handle_parsing_errors=True,
)

In [55]:
agent_executor.run("What is my data about?")



> Entering new AgentExecutor chain...
  Sure! I'd be happy to help you with your SQL database queries.

Question: What is my data about?
Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.

Action: sql_db_list_tables

Action Input: An empty string
Observation: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Thought:  Question: What is my data about?

Thought: Based on the list of tables, it seems that the database is related to music and employees.

Action: sql_db_schema

Action Input: "Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track"
Observation:
CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("Artist

ValueError: An output parsing error occurred. In order to pass this error back to the agent and have it try again, pass `handle_parsing_errors=True` to the AgentExecutor. This is the error: Could not parse LLM output: `This is a well-designed database schema for a music streaming service. Here are some thoughts on the design:

1. Normalization: The schema appears to be normalized, with each table having a primary key and foreign keys to other tables. This helps to reduce data redundancy and improve data integrity.
2. Relationships: The relationships between tables are well-defined, with foreign keys connecting related tables. For example, the Invoice table has foreign keys to the Customer and Track tables, which allows you to link an invoice to a specific customer and track.
3. Data types: The data types for each column are appropriate for the data being stored. For example, the UnitPrice column in the InvoiceLine table is of type NUMERIC(10, 2), which can store decimal values with up to two digits after the decimal point.
4. Indexes: There are no indexes defined in the schema, but it's likely that you would want to add indexes to improve query performance. For example, you might want to add an index on the CustomerId column in the Invoice table to speed up queries that filter by customer.
5. Data consistency: The schema includes constraints to ensure data consistency. For example, the UnitPrice column in the InvoiceLine table has a check constraint to ensure that the value is non-negative.
6. Music-specific features: The schema includes tables and columns that are specific to music, such as the TrackId, AlbumId, and MediaTypeId columns in the Track table. These features help to make the database more tailored to the music streaming service's needs.

Overall, this is a well-designed database schema that should be able to handle the data storage and retrieval needs of a music streaming service.`

In [53]:
agent_output = agent_executor.run("List the names of the employees born after 1973. Share the query used")



> Entering new AgentExecutor chain...
  Sure! I'd be happy to help you with that.

Question: List the names of the employees born after 1973. Share the query used

Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.

Action: sql_db_list_tables

Action Input: An empty string
Observation: Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Thought:  Question: List the names of the employees born after 1973. Share the query used

Thought: I should look at the tables in the database to see what I can query. Then I should query the schema of the most relevant tables.

Action: sql_db_schema

Action Input: Employee
Observation:
CREATE TABLE "Employee" (
	"EmployeeId" INTEGER NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"FirstName" NVARCHAR(20) NOT NULL, 
	"Title" NVARCHAR(30

ValueError: An output parsing error occurred. In order to pass this error back to the agent and have it try again, pass `handle_parsing_errors=True` to the AgentExecutor. This is the error: Could not parse LLM output: `The query "SELECT LastName, FirstName FROM Employee WHERE BirthDate > '1973-08-29'" is correct and does not have any obvious mistakes. However, there are a few best practices that could be followed to make the query more readable and maintainable:

1. Use meaningful table and column aliases: Instead of using the long table and column names, consider using shorter aliases to make the query more readable. For example, you could use "e" for the "Employee" table and "lname" and "fname" for the "LastName" and "FirstName" columns, respectively.
2. Use explicit typing for date comparisons: Instead of using a string literal for the date comparison, consider using an explicit date type to avoid any potential issues with date parsing. For example, you could use a DATE function to convert the string literal to a date type, like this: "SELECT LastName, FirstName FROM Employee WHERE BirthDate > DATE('1973-08-29')".
3. Use a more specific date comparison: Instead of comparing the birth date to a fixed point in time, consider using a more specific date comparison that takes into account the time zone and daylight saving time (if applicable). For example, you could use "SELECT LastName, FirstName FROM Employee WHERE BirthDate > DATE_TRUNC('year', BirthDate) AND BirthDate < DATE_TRUNC('year', BirthDate) + INTERVAL 1 YEAR" to compare the birth date to the current year.
4. Consider using a subquery instead of a join: If you need to filter the results based on a related table, consider using a subquery instead of a join. For example, if you also wanted to list the employees who were hired in a specific year, you could use a subquery like this: "SELECT LastName, FirstName FROM Employee WHERE BirthDate > DATE('1973-08-29') AND HireDate > DATE('2002-01-01')".
5. Use consistent naming conventions: Finally, make sure to use consistent naming conventions throughout the query. For example, use either camelCase or underscores to separate words in column and table names, and avoid using abbreviations or acronyms unless they are clearly defined.`
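Both failures come from the agent's output parser, not from the database: after reading the schema, Llama 2 answered in free prose instead of the `Action:` / `Action Input:` (or `Final Answer:`) format the agent prompt asks for. The sketch below is a simplified, hypothetical parser that illustrates the idea; LangChain's actual ReAct-style parser is more elaborate, and `parse_step` is not part of its API.

```python
import re
from typing import Tuple, Union

# Tool name after "Action:", tool input after "Action Input:".
ACTION_RE = re.compile(r"Action\s*:\s*(.*?)\s*Action\s*Input\s*:\s*(.*)", re.DOTALL)


def parse_step(text: str) -> Union[Tuple[str, str], str]:
    # A final answer ends the agent loop.
    if "Final Answer:" in text:
        return text.split("Final Answer:", 1)[1].strip()
    # Otherwise the model must name a tool and its input.
    match = ACTION_RE.search(text)
    if match is None:
        raise ValueError(f"Could not parse LLM output: `{text}`")
    return match.group(1).strip(), match.group(2).strip().strip('"')


print(parse_step("Thought: check the schema\nAction: sql_db_schema\n\nAction Input: Employee"))
# ('sql_db_schema', 'Employee')
```

Free-form prose such as "This is a well-designed database schema..." contains neither marker, so a parser of this kind can only raise an error or ask the model to retry.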

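Both agent runs fail with output-parsing errors, so there appear to be limitations in how this model formats its output. Note that `handle_parsing_errors=True` was passed to `create_sql_agent`, while the error message asks for it on the `AgentExecutor`; in some LangChain versions, executor options like this one have to go through `agent_executor_kwargs` instead. Let's try the Hugging Face loader.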

#### Test Llama models with the Hugging Face loader

In [5]:
from langchain.llms import HuggingFaceHub
from langchain.llms import HuggingFacePipeline
import torch
from transformers import pipeline, LlamaForCausalLM, LlamaTokenizer

In [6]:
os.environ["HUGGINGFACEHUB_API_TOKEN"] = ""

In [8]:
model_id = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = LlamaTokenizer.from_pretrained(model_id)
model = LlamaForCausalLM.from_pretrained(model_id)

Loading checkpoint shards: 100%|██████████| 2/2 [01:02<00:00, 31.02s/it]


In [15]:
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=10000,  # counts prompt + generated tokens; Llama 2's context window is 4,096 tokens
)

Test that the transformers pipeline works first

In [11]:
pipe(
    'I liked "Breaking Bad" and "Band of Brothers". Do you have any recommendations of other shows I might like?\n',
    do_sample=True,
    top_k=10,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
    max_length=200,
)

[{'generated_text': 'I liked "Breaking Bad" and "Band of Brothers". Do you have any recommendations of other shows I might like?\n\nAnswer:\n\nBased on your interest in "Breaking Bad" and "Band of Brothers," here are some other shows you might enjoy:\n\n1. "The Sopranos" - This HBO series is a crime drama that explores the life of a New Jersey mob boss, Tony Soprano, as he navigates the criminal underworld and deals with personal and family issues.\n2. "The Wire" - This HBO series is a gritty and realistic portrayal of the drug trade in Baltimore, with a focus on the impact it has on the city and its residents.\n3. "Mad Men" - This AMC series is set in the 1960s and follows the lives of advertising executives on Madison Avenue, exploring them'}]

Now test it with text-to-SQL generation in LangChain

In [17]:
llm_model = HuggingFacePipeline(pipeline=pipe, pipeline_kwargs={"verbose": True})

In [18]:
db = SQLDatabase.from_uri("sqlite:///DBs/Chinook.db")

sql_chain = create_sql_query_chain(llm=llm_model, db=db)

In [19]:
sql_query = sql_chain.invoke({
    "question" : "List the names of the employees born after 1973"
})

sql_query

'SELECT Employee.LastName, Employee.BirthDate FROM Employee \nORDER BY Employee.BirthDate DESC'

It runs, but the answer is wrong: instead of filtering on birth date, the query lists every employee's last name and birth date in descending order of birth date.

In [20]:
db.run(sql_query)

"[('Peacock', '1973-08-29 00:00:00'), ('Mitchell', '1973-07-01 00:00:00'), ('King', '1970-05-29 00:00:00'), ('Callahan', '1968-01-09 00:00:00'), ('Johnson', '1965-03-03 00:00:00'), ('Adams', '1962-02-18 00:00:00'), ('Edwards', '1958-12-08 00:00:00'), ('Park', '1947-09-19 00:00:00')]"
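One way to catch this kind of silent failure is to compare the generated query's result with that of a reference query. A minimal sketch with `sqlite3`: the eight `(LastName, BirthDate)` rows returned above are loaded into an in-memory table, and the generated query is compared against a reference that reuses the `> '1973-01-01'` filter from the GPT-3.5 run earlier. `same_result` is a hypothetical helper, and choosing that reference is our own assumption.

```python
import sqlite3

# The rows returned by db.run above.
rows = [
    ("Peacock", "1973-08-29 00:00:00"), ("Mitchell", "1973-07-01 00:00:00"),
    ("King", "1970-05-29 00:00:00"), ("Callahan", "1968-01-09 00:00:00"),
    ("Johnson", "1965-03-03 00:00:00"), ("Adams", "1962-02-18 00:00:00"),
    ("Edwards", "1958-12-08 00:00:00"), ("Park", "1947-09-19 00:00:00"),
]
conn = sqlite3.connect(":memory:")
conn.execute('CREATE TABLE Employee ("LastName" TEXT, "BirthDate" TEXT)')
conn.executemany("INSERT INTO Employee VALUES (?, ?)", rows)


def same_result(conn: sqlite3.Connection, generated: str, reference: str) -> bool:
    # Compare the result sets as sorted lists so that row order does not matter.
    return sorted(conn.execute(generated).fetchall()) == sorted(conn.execute(reference).fetchall())


generated = "SELECT Employee.LastName, Employee.BirthDate FROM Employee \nORDER BY Employee.BirthDate DESC"
reference = "SELECT LastName, BirthDate FROM Employee WHERE BirthDate > '1973-01-01'"
print(same_result(conn, generated, reference))  # False: 8 rows vs. 2
```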