In [1]:
from langchain_openai import ChatOpenAI
import os


In [2]:
# Chat model shared by the SQL-generation chain and the answer-generation chain.
llm = ChatOpenAI(
    # Read the key from the environment so it is never hard-coded in the notebook;
    # os.environ.get passes None through when OPENAI_API_KEY is unset.
    openai_api_key=os.environ.get("OPENAI_API_KEY"),
    model_name="gpt-4o-mini",
    # Zero temperature keeps the generated SQL as repeatable as possible across runs.
    temperature=0,
)


In [3]:
from langchain_community.utilities import SQLDatabase

# Connection URI pointing at the SQLite file used throughout the notebook.
db_uri = "sqlite:///ai_project.db"

# Database wrapper used for both schema lookup and query execution.
# sample_rows_in_table_info=5 is forwarded to the wrapper's table-info settings.
database = SQLDatabase.from_uri(
    db_uri,
    sample_rows_in_table_info=5,
)


In [4]:
from langchain_core.prompts import (
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
    ChatPromptTemplate,
)

# Prompt for the text-to-SQL step. The system message carries the instructions
# and the table schema; the human message carries the user's question.
# The intermediate names are prefixed with ``sql_`` so they stay distinct from
# any other prompt built in this notebook.
sql_system_template = """You are an SQL translator that translates the user's question into a SQL query.
Form the query using SQLite syntax and refer to the following table schema:
{schema}

Return the query without any formatting."""
sql_system_prompt = SystemMessagePromptTemplate.from_template(sql_system_template)

sql_human_template = "Question: {question}"
sql_human_prompt = HumanMessagePromptTemplate.from_template(sql_human_template)

# Expects the input variables ``schema`` and ``question``.
sql_chat_prompt = ChatPromptTemplate.from_messages(
    [sql_system_prompt, sql_human_prompt]
)


In [5]:
def get_schema(_):
    """Return the table info reported by the notebook's ``database`` wrapper.

    The single positional argument is ignored; accepting it lets the function
    be used as a chain step, which is called with the chain's current input.
    """
    return database.get_table_info()


In [6]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

def _strip_sql_fences(text):
    """Remove a Markdown code fence the model may wrap around its SQL.

    Chat models sometimes answer with a fenced block (optionally tagged
    ``sql``) even when asked for plain text; that wrapper is not valid SQL.

    Args:
        text: Raw string produced by the model.

    Returns:
        The text without a surrounding fence and without leading or trailing
        whitespace. Unfenced text is returned stripped of outer whitespace only.
    """
    cleaned = text.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned[3:]
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]
        first_line, newline, remainder = cleaned.partition("\n")
        # Only discard the first line when it is empty or a language tag, so a
        # query written on the same line as the opening fence is preserved.
        if newline and first_line.strip().lower() in ("", "sql", "sqlite"):
            cleaned = remainder
    return cleaned.strip()


# Question -> SQL text: add the schema to the input, render the prompt, call the
# model, take the reply as a string, then drop any Markdown fence around it.
# The plain function at the end of the pipe is coerced into a runnable step.
sql_chain = (
    RunnablePassthrough.assign(schema=get_schema)
    | sql_chat_prompt
    | llm
    | StrOutputParser()
    | _strip_sql_fences
)


In [7]:
# Run only the SQL-generation chain; the cell's value is the generated query text.
user_question = "Which trains will go to Mumbai?"
sql_chain.invoke(dict(question=user_question))


"SELECT * FROM TrainSchedule WHERE Destinations LIKE '%Mumbai%';"

In [24]:
# Prompt for the answer step. The system message holds the instructions, the
# schema and the original question; the human message holds the executed SQL
# and the rows it returned.
system_template = """You will be provided with a SQL query and its response.
Based on the table schema and question provided below, derive the answer from the response.
Be precise and concise. Do not provide any formatting.

Additional instructions:
1. The DaysOfOperation column contains either a list of days like Mon, Tue, Wed, etc., or Daily if the train operates every day of the week.
2. Mention both the train name as well as the train number when asked for information regarding any train.

Schema: {schema}
Question: {question}"""
system_prompt = SystemMessagePromptTemplate.from_template(system_template)

human_template = """Query: {query}
Response: {response}"""
human_prompt = HumanMessagePromptTemplate.from_template(human_template)

# Expects the input variables ``schema``, ``question``, ``query`` and ``response``.
answer_chat_prompt = ChatPromptTemplate.from_messages([system_prompt, human_prompt])


In [25]:
def run_query(query):
    """Execute a read-only SQL statement against the notebook's ``database``.

    The SQL passed in is produced by a language model from a free-text
    question, so it is treated as untrusted: only statements whose first
    keyword is ``SELECT`` or ``WITH`` are executed.

    NOTE(review): this is a coarse guard, not a sandbox (for example a ``WITH``
    clause can still precede a write statement). Consider also opening the
    database in read-only mode.

    Args:
        query: SQL text to execute.

    Returns:
        Whatever ``database.run`` returns for the statement.

    Raises:
        ValueError: If the text is empty or does not start with SELECT/WITH.
    """
    statement = query.strip()
    first_keyword = statement.split(None, 1)[0].upper() if statement else ""
    if first_keyword not in ("SELECT", "WITH"):
        raise ValueError(
            f"Refusing to execute non-read-only or empty SQL: {query!r}"
        )
    return database.run(query)


In [26]:
# Question -> natural-language answer.
# Stage 1 adds ``query`` (the SQL produced by sql_chain) to the input mapping.
# Stage 2 adds ``schema`` and ``response`` (the result of running that SQL).
# The enriched mapping then fills the answer prompt, which is sent to the model
# and the reply is returned as a plain string.
full_chain = (
    RunnablePassthrough.assign(query=sql_chain).assign(
        schema=get_schema,
        # ``inputs`` is the mapping built so far; it already contains "query".
        response=lambda inputs: run_query(inputs["query"]),
    )
    | answer_chat_prompt
    | llm
    | StrOutputParser()
)


In [27]:
# Same question as before, now through the full chain (generate SQL, run it, answer).
user_question = "Which trains will go to Mumbai?"
full_chain.invoke(dict(question=user_question))


'Konkan Kanya Express, Train Number 11014'

In [28]:
# Ask a question that requires comparing trains rather than a simple lookup.
user_question = "Which is the fastest train to Bangalore?"
full_chain.invoke(dict(question=user_question))


'The fastest train to Bangalore is Brindavan Express.'

In [29]:
# Two-part question: the answer must cover both the arrival time and the platform.
user_question = (
    "When is the next train to Surat? "
    "Can you tell me when it will arrive and on which platform?"
)
full_chain.invoke(dict(question=user_question))


'The next train to Surat is the Golden Temple Mail (Train Number 12903), which will arrive at 18:00:00 on platform 2.'

In [30]:
# print() is used here so a multi-line answer is shown with real line breaks.
user_question = "Which trains do not function on weekends?"
print(full_chain.invoke(dict(question=user_question)))


The trains that do not function on weekends are:
1. Rajdhani Express (Train Number: 12301)
2. Karnataka Express (Train Number: 12622)
3. Shatabdi Express (Train Number: 22406)
4. Andhra Pradesh Express (Train Number: 12716)
5. Gorakhpur Express (Train Number: 12555)
6. Magadh Express (Train Number: 12391)
7. Intercity Express (Train Number: 12012)
