In [1]:
from langchain_core.documents import Document
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_groq import ChatGroq
from langchain_cohere import CohereEmbeddings, CohereRerank
from langchain.document_loaders import JSONLoader
from langchain_community.vectorstores import FAISS
import os
import json

from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv())


True

#### Api key

In [2]:
groq_api_key = os.getenv("groq_api_key")
if groq_api_key is None:
    raise ValueError("Did not find groq_api_key, please add an environment variable `GROQ_API_KEY` which contains it, or pass `groq_api_key` as a named parameter.")

cohere_api_key = os.getenv("cohere_api_key")
if cohere_api_key is None:
    raise ValueError("Did not find cohere_api_key, please add an environment variable `COHERE_API_KEY`")

In [3]:
query = "How many different aircraft models are there in the aircrafts_data table?"

## LLM

In [4]:
# llm
llm = ChatGroq(
    model="llama3-8b-8192",
    api_key=groq_api_key,
    temperature=0,
    model_kwargs={
        "top_p": 0.95,
        "frequency_penalty": 0.1,
        "presence_penalty": 0.2,
    },
    verbose=True
)

# Embedding
embeddings = CohereEmbeddings(cohere_api_key=cohere_api_key)

# db
db = SQLDatabase.from_uri("sqlite:///../data/travel.sqlite")

## Indexing Table

In [5]:

toolkit = SQLDatabaseToolkit(db=db, llm=llm)
context = toolkit.get_context()

In [6]:
docs_table = JSONLoader(
    file_path="../data/rag-schema/table.jsonl",
    jq_schema='.', 
    text_content=False, 
    json_lines=True).load()

In [7]:
db_table = FAISS.from_documents(documents=docs_table, embedding=embeddings)
retriever_table = db_table.as_retriever(search_type='mmr', search_kwargs={'k': 5, 'lambda_mult': 1})

In [8]:
matched_documents_table = retriever_table.invoke(query)
matched_tables = []

for document in matched_documents_table:
    page_content = document.page_content
    page_content = json.loads(page_content)
    table_name = page_content['table']
    matched_tables.append(table_name)

print(f'Matched tables = {matched_tables}')

Matched tables = ['aircrafts_data', 'seats', 'flights', 'boarding_passes', 'tickets']


## Indexing Columns

In [9]:
docs_columns = JSONLoader(
    file_path="../data/rag-schema/column.jsonl",
    jq_schema='.', 
    text_content=False, 
    json_lines=True).load()

In [10]:
db_columns = FAISS.from_documents(documents=docs_columns, embedding=embeddings)
retriever_columns = db_columns.as_retriever(search_type='similarity', search_kwargs={'k': 20})

In [11]:
matched_columns = retriever_columns.invoke(query)

matched_columns_filtered = []
for doc in matched_columns:
    page_content = json.loads(doc.page_content)
    table_ = page_content["table_name"]
    if table_ in matched_tables:
        matched_columns_filtered.append(page_content)

In [12]:
matched_columns_cleaned = []

for doc in matched_columns_filtered:
    table_name = doc['table_name']
    column_name = doc['column_name']
    data_type = doc['data_type']
    matched_columns_cleaned.append(f'table_name={table_name}|column_name={column_name}|data_type={data_type}')
    
matched_columns_cleaned = '\n'.join(matched_columns_cleaned)

## Text-to-SQL generation

In [13]:
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser


In [14]:
prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
    You are a SQL master expert specializing in writing complex SQL queries for SQLite. Your task is to construct a SQL query based on the provided information. Follow these strict rules:

    USER_QUERY: {query}
    -------
    MATCHED_SCHEMA: {matched_schema}
    -------

    Please construct a SQL query using the MATCHED_SCHEMA and the USER_QUERY provided above.
    Generate ONLY the SQL query. Do not provide any explanations, comments, or additional text. 
    The goal is to determine the availability of hotels based on the provided info. 

    IMPORTANT: Use ONLY the column names (column_name) mentioned in MATCHED_SCHEMA. DO NOT USE any other column names outside of this. 
    IMPORTANT: Associate column_name mentioned in MATCHED_SCHEMA only to the table_name specified under MATCHED_SCHEMA.
    NOTE: Use SQL 'AS' statement to assign a new name temporarily to a table column or even a table wherever needed. 

    <|eot_id|><|start_header_id|>assistant<|end_header_id|>
    """,
    input_variables=["query", "matched_schema"],
)

sql_gen = prompt | llm | StrOutputParser()
result_sql = sql_gen.invoke(
    {
        "query": query,
        "matched_schema": matched_columns_cleaned,
    }
)
response_sql = db.run(result_sql)

In [15]:
prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
    You are a travel assistant your task transform the SQL query result into a natural, conversational response.


    User's Question:
    ----------------
    {user_query}

    SQL Result:
    ----------------
    {sql_response}

    
    <|eot_id|><|start_header_id|>assistant<|end_header_id|>
    """,
        input_variables=["user_query", "sql_response"],
    )

response_llm = prompt | llm | StrOutputParser()
response = response_llm.invoke({"user_query": query, "sql_response": response_sql})
print(response)

According to our database, there are 9 different aircraft models in the aircrafts_data table.


# Human response friendly

In [31]:
def get_response(query: str) -> str:
    matched_table = retriever_table.invoke(query)
    matched_tables = []
    for document in matched_table:
        page_content = document.page_content
        page_content = json.loads(page_content)
        table_name = page_content['table']
        matched_tables.append(table_name)

    matched_columns = retriever_columns.invoke(query)
    matched_columns_filtered = []
    for doc in matched_columns:
        page_content = json.loads(doc.page_content)
        table_ = page_content["table_name"]
        if table_ in matched_tables:
            matched_columns_filtered.append(page_content)

    matched_columns_cleaned = []
    for doc in matched_columns_filtered:
        table_name = doc['table_name']
        column_name = doc['column_name']
        data_type = doc['data_type']
        matched_columns_cleaned.append(f'table_name={table_name}|column_name={column_name}|data_type={data_type}')
        
    matched_columns_cleaned = '\n'.join(matched_columns_cleaned)

    result_sql = sql_gen.invoke(
        {
            "query": query,
            "matched_schema": matched_columns_cleaned,
        }
    )
    
    response_sql = db.run(result_sql)
    response = response_llm.invoke({"user_query": query, "sql_response": response_sql})
    print(f"Query: {query}")
    print("="*50)
    print(f"SQL: {result_sql}")
    print("="*50)
    print(f"Response: \n{response}")

In [17]:
# Easy
get_response("How many different aircraft models are there in the aircrafts_data table? and what are the models?")

Query: How many different aircraft models are there in the aircrafts_data table? and what are the models?
Response: 
According to the data, there are 9 different aircraft models in the aircrafts_data table. The models are:

* Boeing 777-300
* Boeing 767-300
* Sukhoi Superjet-100
* Airbus A320-200
* Airbus A321-200
* Airbus A319-100
* Boeing 737-300
* Cessna 208 Caravan
* Bombardier CRJ-200

Let me know if you'd like to know more about any of these models or if you have any other questions!


In [18]:
# Easy
get_response("What is the total number of bookings made in April 2024?")

OperationalError: (sqlite3.OperationalError) near "FROM": syntax error
[SQL: SELECT COUNT(*) 
FROM bookings 
WHERE EXTRACT(MONTH FROM book_date) = 4 AND EXTRACT(YEAR FROM book_date) = 2024;]
(Background on this error at: https://sqlalche.me/e/20/e3q8)

In [19]:
# moderate
get_response("Which airport has the most departing flights?")

Query: Which airport has the most departing flights?
Response: 
According to my research, the airport with the most departing flights is actually EuroAirport Basel-Mulhouse-Freiburg. It's a major hub that connects passengers to various destinations around the world. If you're planning a trip, you might want to consider flying out of this airport for your next adventure!


In [20]:
# moderate
get_response("What is the average ticket price for Business class flights?")

Query: What is the average ticket price for Business class flights?
Response: 
According to our data, the average ticket price for Business class flights is approximately $51,143.


In [21]:
# moderate
get_response("How many car rentals are available in Basel for the luxury price tier?")

Query: How many car rentals are available in Basel for the luxury price tier?
Response: 
It looks like there aren't any car rentals available in Basel that fit the luxury price tier. It's possible that the luxury options might be booked up or not available at the moment. I'd be happy to help you explore other options or check availability for different price tiers if you'd like!


In [None]:
# Difficult 
# runtime error
# get_response("What is the most popular seat (by frequency of booking) for each aircraft type?")

In [23]:
# Difficult 
get_response("Calculate the average delay time for flights that arrived late, grouped by departure airport.")

OperationalError: (sqlite3.OperationalError) no such function: DATEDIFF
[SQL: SELECT 
    a.departure_airport, 
    AVG(DATEDIFF(a.actual_arrival, a.scheduled_arrival)) AS average_delay
FROM 
    flights a
WHERE 
    a.status = 'arrived' AND 
    DATEDIFF(a.actual_arrival, a.scheduled_arrival) > 0
GROUP BY 
    a.departure_airport;]
(Background on this error at: https://sqlalche.me/e/20/e3q8)

In [25]:
# Difficult
get_response("For each passenger who has made multiple bookings, what is the total amount spent on tickets and the average duration of their trips?")

OperationalError: (sqlite3.OperationalError) no such column: p.passenger_id
[SQL: SELECT 
    p.passenger_id, 
    SUM(tf.amount) AS total_amount, 
    AVG(DATEDIFF(tf.actual_arrival, tf.actual_departure)) AS average_duration
FROM 
    tickets t 
JOIN 
    ticket_flights tf ON t.book_ref = tf.ticket_no 
JOIN 
    bookings b ON t.passenger_id = b.passenger_id 
WHERE 
    t.passenger_id IN (
        SELECT 
            passenger_id 
        FROM 
            tickets 
        GROUP BY 
            passenger_id 
        HAVING 
            COUNT(*) > 1
    )
GROUP BY 
    p.passenger_id;]
(Background on this error at: https://sqlalche.me/e/20/e3q8)

In [27]:
# Very Difficult
get_response("Identify any circular trips in the database (where a passenger starts and ends at the same airport with intermediate stops) and calculate the total duration of each such trip.")

OperationalError: (sqlite3.OperationalError) wrong number of arguments to function GROUP_CONCAT()
[SQL: WITH 
    flight_data AS (
        SELECT 
            f1.arrival_airport, 
            f1.departure_airport, 
            f1.scheduled_departure, 
            f1.scheduled_arrival, 
            f1.actual_departure, 
            f1.actual_arrival, 
            f1.flight_id, 
            f1.aircraft_code, 
            f1.status
        FROM 
            flights f1
        WHERE 
            f1.status = 'landed'
    ),
    trip_data AS (
        SELECT 
            t1.id, 
            t1.booked, 
            f2.arrival_airport, 
            f2.departure_airport
        FROM 
            trip_recommendations t1
        JOIN 
            ticket_flights t2 ON t1.id = t2.flight_id
        JOIN 
            flights f2 ON t2.flight_id = f2.flight_id
        WHERE 
            f2.status = 'landed'
    ),
    circular_trips AS (
        SELECT 
            t1.id, 
            t1.booked, 
            f1.arrival_airport, 
            f1.departure_airport, 
            f1.scheduled_departure, 
            f1.scheduled_arrival, 
            f1.actual_departure, 
            f1.actual_arrival, 
            f1.flight_id, 
            f1.aircraft_code, 
            f1.status
        FROM 
            trip_data t1
        JOIN 
            flight_data f1 ON t1.arrival_airport = f1.departure_airport AND t1.departure_airport = f1.arrival_airport
    )
SELECT 
    id, 
    booked, 
    GROUP_CONCAT(DISTINCT arrival_airport, ' -> ', departure_airport, ' (', scheduled_departure, ' - ', scheduled_arrival, ')', ' (', actual_departure, ' - ', actual_arrival, ')') AS trip_details, 
    SUM(DATEDIFF(scheduled_arrival, scheduled_departure)) AS total_duration
FROM 
    circular_trips
GROUP BY 
    id, 
    booked
ORDER BY 
    total_duration DESC;]
(Background on this error at: https://sqlalche.me/e/20/e3q8)

In [32]:
# Very difficult
get_response("For each city, calculate the ratio of incoming tourists (based on hotel bookings by non-locals) to outgoing tourists (based on flight bookings by locals) for each month. Identify any significant seasonal patterns.")

Query: For each city, calculate the ratio of incoming tourists (based on hotel bookings by non-locals) to outgoing tourists (based on flight bookings by locals) for each month. Identify any significant seasonal patterns.
SQL: WITH 
    hotel_bookings AS (
        SELECT 
            location, 
            strftime('%Y-%m', checkin_date) AS month, 
            COUNT(*) AS hotel_bookings
        FROM 
            hotels
        WHERE 
            booked > 0
        GROUP BY 
            location, 
            strftime('%Y-%m', checkin_date)
    ),
    flight_bookings AS (
        SELECT 
            departure_airport, 
            strftime('%Y-%m', scheduled_arrival) AS month, 
            COUNT(*) AS flight_bookings
        FROM 
            flights
        WHERE 
            actual_arrival IS NOT NULL
        GROUP BY 
            departure_airport, 
            strftime('%Y-%m', scheduled_arrival)
    ),
    combined_bookings AS (
        SELECT 
            h.location, 
            h