In [39]:
import re
import sqlite3

import duckdb
import pandas as pd
from dotenv import load_dotenv
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_community.utilities import SQLDatabase
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

In [2]:
# Pull environment variables from the repository-level .env file
# (presumably the OpenAI key used by ChatOpenAI / OpenAIEmbeddings — confirm).
load_dotenv("../.env")

True

In [3]:
# temperature=0 makes the generated SQL as deterministic as the API allows,
# so Restart & Run All reproduces the recorded outputs more closely.
llm = ChatOpenAI(temperature=0)

In [4]:
# Plain filesystem path to the Chinook sample database. The DuckDB ATTACH cell
# uses this path directly, so it must stay a bare path.
sqlite_uri = "../data/Chinook.db"

# Both connections are opened read-only (SQLite URI `mode=ro`):
# - every query run through them below is LLM-generated and executed verbatim;
# - a wrong path now raises instead of silently creating an empty database file.
sqlite_alchemy_uri = f"sqlite:///file:{sqlite_uri}?mode=ro&uri=true"

db = SQLDatabase.from_uri(sqlite_alchemy_uri)
sqlite_con = sqlite3.connect(f"file:{sqlite_uri}?mode=ro", uri=True)

In [60]:
# A separate in-memory DuckDB connection; the SQLite file is attached to it later.
duckdb_con = duckdb.connect(database=":memory:")

In [29]:
def get_schema(_):
    """Return the table description produced by `db.get_table_info()`.

    The single positional argument is ignored. It exists because
    `RunnablePassthrough.assign(schema=get_schema)` calls this with the
    chain's input dict.
    """
    table_info = db.get_table_info()
    return table_info

In [18]:
def run_query(query):
    """Execute `query` through the LangChain `SQLDatabase` wrapper `db`.

    Returns:
        Whatever `db.run` returns for the query.
    """
    query_result = db.run(query)
    return query_result

In [21]:
def get_text_chunks(raw_text, chunk_size=1000, chunk_overlap=100):
    """Split `raw_text` into overlapping chunks on newline boundaries.

    Args:
        raw_text: Text to split (in this notebook, the output of
            `db.get_table_info()`).
        chunk_size: Target maximum chunk length in characters, measured with
            `len`. Defaults to 1000, the previously hard-coded value.
        chunk_overlap: Characters shared between consecutive chunks.
            Defaults to 100, the previously hard-coded value.

    Returns:
        The list of chunk strings from `CharacterTextSplitter.split_text`.
    """
    text_splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
    )
    return text_splitter.split_text(raw_text)

In [27]:
def get_vectorstore(text_chunks):
    """Embed `text_chunks` with `OpenAIEmbeddings` and index them in a FAISS store.

    Returns:
        The FAISS vector store built by `FAISS.from_texts`.
    """
    return FAISS.from_texts(
        texts=text_chunks,
        embedding=OpenAIEmbeddings(),
    )

In [42]:
def get_conversation_chain(vectorstore):
    """Build a `ConversationalRetrievalChain` over `vectorstore` using the global `llm`.

    Each call creates a fresh `ConversationBufferMemory`. It stores turns
    under the "chat_history" key as message objects (`return_messages=True`).

    Returns:
        The newly constructed conversational chain.
    """
    doc_retriever = vectorstore.as_retriever()
    chat_memory = ConversationBufferMemory(
        return_messages=True,
        memory_key="chat_history",
    )
    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=doc_retriever,
        memory=chat_memory,
    )

In [45]:
def handle_user_question(user_question, conversation):
    """Send one question to a conversational chain and return its output dict.

    Uses the Runnable `invoke` API; calling a LangChain Chain directly
    (`conversation({...})`) is deprecated.

    Args:
        user_question: The question text.
        conversation: A chain such as the one from `get_conversation_chain`.

    Returns:
        The chain's output mapping (in this notebook it contains "question",
        "chat_history" and "answer").
    """
    return conversation.invoke({"question": user_question})

## Attach the SQLite database to DuckDB


In [62]:
# Expose the SQLite file inside DuckDB under the catalog name `test`.
# - IF NOT EXISTS lets the cell be re-run without an "already exists" error.
# - Doubling single quotes keeps a path containing `'` from breaking the literal.
escaped_sqlite_path = sqlite_uri.replace("'", "''")
duckdb_con.execute(f"ATTACH IF NOT EXISTS '{escaped_sqlite_path}' AS test (TYPE sqlite)");

<duckdb.duckdb.DuckDBPyConnection at 0x7a078276f730>

In [64]:
# Make the attached SQLite catalog the connection's default. This is a lasting
# session change that later unqualified queries can rely on, so it runs as its
# own statement instead of being hidden inside the display query.
duckdb_con.execute("USE test")
duckdb_con.query("SHOW TABLES")

┌───────────────┐
│     name      │
│    varchar    │
├───────────────┤
│ Album         │
│ Artist        │
│ Customer      │
│ Employee      │
│ Genre         │
│ Invoice       │
│ InvoiceLine   │
│ MediaType     │
│ Playlist      │
│ PlaylistTrack │
│ Track         │
├───────────────┤
│    11 rows    │
└───────────────┘

## Generate a SQL query from a user question


In [54]:
# Direct text-to-SQL prompt: {schema} is filled by get_schema, {question} by the caller.
# The model's answer is executed verbatim, so ask for bare SQLite SQL only.
sql_query_template = """Based on the table schema below, write a SQLite query that would answer the user's question.
Return only the SQL query, with no explanation and no markdown formatting.
{schema}

Question: {question}
SQL Query:"""

query_prompt = ChatPromptTemplate.from_template(sql_query_template)

In [55]:
def strip_sql_fences(text: str) -> str:
    """Return the SQL inside a markdown code fence, or `text` stripped if there is none.

    Chat models sometimes answer with a fenced block instead of bare SQL.
    The fence markers are not valid SQL, so `pd.read_sql_query` / `db.run`
    would fail on them. An optional language tag such as `sql` is dropped.
    """
    fenced = re.search(r"```(?:[\w-]*\n)?(.*?)```", text, flags=re.DOTALL)
    return (fenced.group(1) if fenced else text).strip()


# Fill {schema} from the live database, ask the model for SQL, then clean the
# reply so it can be executed as-is.
sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | query_prompt
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
    | strip_sql_fences
)

In [9]:
# Generate (but do not run) SQL for a sales question.
user_question = "what are the top 5 artists by sales?"
sql_chain_result = sql_chain.invoke(dict(question=user_question))
sql_chain_result

'SELECT ar.Name AS Artist, SUM(il.UnitPrice) AS TotalSales\nFROM Artist ar\nJOIN Album al ON ar.ArtistId = al.ArtistId\nJOIN Track t ON al.AlbumId = t.AlbumId\nJOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY ar.Name\nORDER BY TotalSales DESC\nLIMIT 5;'

In [10]:
# Execute the generated SQL against the SQLite connection and show the result.
top_artists_df = pd.read_sql_query(sql_chain_result, sqlite_con)
top_artists_df.head()

Unnamed: 0,Artist,TotalSales
0,Iron Maiden,138.6
1,U2,105.93
2,Metallica,90.09
3,Led Zeppelin,86.13
4,Lost,81.59


## Generate a SQL query from a user question using vector-store context


Reference: https://python.langchain.com/docs/expression_language/how_to/passthrough


In [67]:
# Index chunks of the schema text in FAISS so the prompt only receives the
# parts relevant to each question.
raw_text = get_schema(None)  # same db.get_table_info() text that sql_chain uses
text_chunks = get_text_chunks(raw_text)
vector_store = get_vectorstore(text_chunks)
retriever = vector_store.as_retriever()

In [81]:
# Retrieval-augmented text-to-SQL prompt: {context} receives schema chunks
# retrieved from the FAISS store instead of the full schema. The reply is
# executed verbatim, so ask for bare SQLite SQL only.
context_sql_template = """Write a SQLite query that would answer the user's question based on the context below.
Return only the SQL query, with no explanation and no markdown formatting.
{context}

Question: {question}
SQL Query:"""

context_query_prompt = ChatPromptTemplate.from_template(context_sql_template)

In [82]:
def format_docs(docs):
    """Join the text of retrieved documents, separated by blank lines."""
    return "\n\n".join(doc.page_content for doc in docs)


# Retrieve schema chunks for the question and render them as plain text,
# rather than passing the stringified Document list into the prompt. The raw
# question string is passed through unchanged.
context_sql_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | context_query_prompt
    | llm.bind(stop=["\nSQLResult:"])
    | StrOutputParser()
)

In [83]:
# Same question as before, now answered from retrieved schema chunks only.
user_question = "what are the top 5 artists by sales?"
context_sql_chain_result = context_sql_chain.invoke(input=user_question)
context_sql_chain_result

'SELECT ar.Name AS Artist, SUM(il.UnitPrice * il.Quantity) AS TotalSales\nFROM Artist ar\nJOIN Album al ON ar.ArtistId = al.ArtistId\nJOIN Track t ON al.AlbumId = t.AlbumId\nJOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY ar.Name\nORDER BY TotalSales DESC\nLIMIT 5;'

In [84]:
# Execute the retrieval-augmented SQL against the SQLite connection.
context_top_artists_df = pd.read_sql_query(context_sql_chain_result, sqlite_con)
context_top_artists_df.head()

Unnamed: 0,Artist,TotalSales
0,Iron Maiden,138.6
1,U2,105.93
2,Metallica,90.09
3,Led Zeppelin,86.13
4,Lost,81.59


## Generate a natural-language answer to a user question


In [11]:
# Prompt for turning a query result into prose. {response} holds the output of
# running the generated query (see full_chain), labelled "SQL Response".
response_template = """Based on the table schema, the question, the SQL query, and its result below, write a natural language answer to the question.
{schema}

Question: {question}
SQL Query: {query}
SQL Response: {response}"""

prompt_response = ChatPromptTemplate.from_template(template=response_template)

In [13]:
# End-to-end chain:
# 1. generate SQL with sql_chain;
# 2. execute it via run_query and re-attach the schema;
# 3. ask the model to phrase the result as prose.
generate_and_execute = RunnablePassthrough.assign(query=sql_chain).assign(
    schema=get_schema,
    response=lambda inputs: run_query(inputs["query"]),
)
full_chain = (
    generate_and_execute
    | prompt_response
    | llm.bind(stop=["\nNatural Language Response:"])
    | StrOutputParser()
)

In [14]:
# Generate SQL, run it, and get a prose answer in one call.
user_question = "how many albums are there in the database?"
full_chain.invoke(dict(question=user_question))

'There are 347 albums in the database.'

## RAG with chat history and FAISS


In [43]:
# Reuse the FAISS index built earlier from the same db.get_table_info() text,
# instead of re-embedding identical chunks (an extra paid embeddings call).
# Re-running this cell also resets the chat memory.
conversation = get_conversation_chain(vectorstore=vector_store)

In [56]:
# Ask the conversational RAG chain for a query; memory keeps earlier turns.
user_question = "write a query that answer the question: how many albums are there in the database?"
response = handle_user_question(user_question=user_question, conversation=conversation)
response

{'question': 'write a query that answer the question: how many albums are there in the database?',
 'chat_history': [HumanMessage(content='how many albums are there in the database?'),
  AIMessage(content='There are 3 albums in the database.'),
  HumanMessage(content='how many albums are there in the database?'),
  AIMessage(content='There are 3 albums in the database.'),
  HumanMessage(content='how many albums are there in the database?'),
  AIMessage(content='There are 3 albums in the database.'),
  HumanMessage(content='write a query that answer the question: how many albums are there in the database?'),
  AIMessage(content='To find out how many albums are in the database, you can use the following SQL query:\n\n```sql\nSELECT COUNT(*) AS TotalAlbums\nFROM Album;\n```\n\nThis query will return the total number of albums present in the database.')],
 'answer': 'To find out how many albums are in the database, you can use the following SQL query:\n\n```sql\nSELECT COUNT(*) AS TotalAlb

In [57]:
# Print the accumulated conversation. Each message is labelled by its type
# rather than by position, so the transcript stays correct even if the history
# does not strictly alternate user/assistant turns.
speaker_labels = {"human": "User", "ai": "Bot"}
chat_history = response["chat_history"]
for message in chat_history:
    speaker = speaker_labels.get(message.type, message.type)
    print(f"{speaker}: {message.content}")

User: how many albums are there in the database?
Bot: There are 3 albums in the database.
User: how many albums are there in the database?
Bot: There are 3 albums in the database.
User: how many albums are there in the database?
Bot: There are 3 albums in the database.
User: write a query that answer the question: how many albums are there in the database?
Bot: To find out how many albums are in the database, you can use the following SQL query:

```sql
SELECT COUNT(*) AS TotalAlbums
FROM Album;
```

This query will return the total number of albums present in the database.


In [59]:
# Run the query proposed by the conversational chain against SQLite directly.
album_count_df = pd.read_sql_query(
    "SELECT COUNT(*) AS TotalAlbums FROM Album;", sqlite_con
)
album_count_df.head()

Unnamed: 0,TotalAlbums
0,347


In [65]:
# Cross-check the count through DuckDB. Qualifying the table with the attached
# catalog `test` means this cell does not depend on an earlier `USE test`.
duckdb_con.query("SELECT COUNT(*) AS TotalAlbums FROM test.Album;").df()

Unnamed: 0,TotalAlbums
0,347
