# In [None]:
import streamlit as st
from pathlib import Path
from langchain.agents import create_sql_agent
from langchain.sql_database import SQLDatabase
from langchain.agents.agent_types import AgentType
from langchain.callbacks import StreamlitCallbackHandler
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from sqlalchemy import create_engine  # maps the output that comes from SQL
import sqlite3
from langchain_groq import ChatGroq

# Page chrome.
st.set_page_config(page_icon="🦜🔗", page_title="Chat SQL", layout="wide")
st.title("Langchain : Chat with SQL DB")

# Warning text about LLM-generated SQL. It is only defined here; nothing in
# this section renders it.
INJECTION_WARNING = """
⚠️ 🧠 Warning: This app uses the groq to generate SQL queries.
This is for demo purposes only and should not be used maliciously.
"""

# Sentinel values that tell configure_database() which backend to open.
LOCALDB = "USE_LOCALDB"
MYSQL = "USE_MYSQL"

# Sidebar choices: index 0 is the local database, index 1 is MySQL.
radio_opt = ["USE_LOCALDB", "Connect to MySQL"]

selected_option = st.sidebar.radio(label="Select Database", options=radio_opt)

if radio_opt.index(selected_option) == 1:
    # MySQL selected: collect the connection details from the sidebar.
    db_uri = MYSQL
    mysql_host = st.sidebar.text_input("MySQL Host", "localhost")
    mysql_port = st.sidebar.text_input("MySQL Port", "3306")
    mysql_user = st.sidebar.text_input("MySQL User", "root")
    mysql_password = st.sidebar.text_input("MySQL Password", "", type="password")
    mysql_database = st.sidebar.text_input("MySQL Database", "test")
else:
    db_uri = LOCALDB

api_key = st.sidebar.text_input("Groq API Key", "", type="password")

# Halt this script run when a prerequisite is missing. Showing the hint alone
# would let the code below go on to build the LLM and agent without it.
if not db_uri:
    st.info("Please enter the database details")
    st.stop()

if not api_key:
    st.info("Please enter the groq api key")
    st.stop()

## LLM model

llm = ChatGroq(groq_api_key=api_key, model_name="Llama3-8b-8192", streaming=True)

@st.cache_resource(ttl="2h")
def configure_database(db_uri, mysql_host=None, mysql_port=None, mysql_user=None, mysql_password=None, mysql_database=None):
    """Build the SQLDatabase wrapper for the selected backend.

    The result is cached through ``st.cache_resource`` with a ttl of "2h".

    Args:
        db_uri: ``LOCALDB`` to open the bundled ``Student.db`` SQLite file,
            or ``MYSQL`` to connect to a MySQL server.
        mysql_host: MySQL host name (used only when ``db_uri == MYSQL``).
        mysql_port: MySQL port (used only when ``db_uri == MYSQL``).
        mysql_user: MySQL user name (used only when ``db_uri == MYSQL``).
        mysql_password: MySQL password (used only when ``db_uri == MYSQL``).
        mysql_database: MySQL database name (used only when ``db_uri == MYSQL``).

    Returns:
        A ``SQLDatabase`` bound to the requested engine.

    Raises:
        ValueError: If any MySQL detail is empty, or ``db_uri`` is neither
            ``LOCALDB`` nor ``MYSQL``.
    """
    # Function-scope import keeps this block self-contained.
    from urllib.parse import quote

    if db_uri == LOCALDB:
        # Student.db is expected to sit next to this script.
        db_file_path = (Path(__file__).parent / "Student.db").absolute()

        def open_read_only():
            # mode=ro in the SQLite URI opens the file read-only.
            return sqlite3.connect(f"file:{db_file_path}?mode=ro", uri=True)

        return SQLDatabase(create_engine("sqlite:///", creator=open_read_only))

    if db_uri == MYSQL:
        if not (mysql_host and mysql_port and mysql_user and mysql_password and mysql_database):
            raise ValueError("Please provide all the mysql details")
        # Percent-encode the credentials so characters such as '@', ':' or '/'
        # in a user name or password cannot be mistaken for URL delimiters.
        user = quote(mysql_user, safe="")
        password = quote(mysql_password, safe="")
        return SQLDatabase(
            create_engine(
                f"mysql+mysqlconnector://{user}:{password}@{mysql_host}:{mysql_port}/{mysql_database}"
            )
        )

    raise ValueError(f"Invalid database URI: {db_uri}")


# Open the selected database. configure_database() raises ValueError when the
# MySQL details are incomplete; show that as a message and halt this run
# instead of letting the exception surface as a traceback.
try:
    if db_uri == MYSQL:
        db = configure_database(db_uri, mysql_host, mysql_port, mysql_user, mysql_password, mysql_database)
    else:
        db = configure_database(db_uri)
except ValueError as exc:
    st.error(str(exc))
    st.stop()

## toolkit
# SQLDatabaseToolkit takes its fields by keyword only; positional arguments
# raise a TypeError.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

agent = create_sql_agent(
    llm,
    toolkit=toolkit,
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)


# Evaluate the button first so it is rendered on every run, including the
# first one of a session.
clear_history_clicked = st.sidebar.button("Clear messages History")

# Seed the history when the session has none yet, or reset it on request.
# The check must be "not in": otherwise a fresh session never gets a
# "messages" entry and the loop below fails on the missing attribute.
if "messages" not in st.session_state or clear_history_clicked:
    st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you ?"}]

# Replay the stored conversation.
for msg in st.session_state.messages:
    st.chat_message(msg["role"]).write(msg["content"])

# Prompt box at the bottom of the page; the turn below runs only when it
# yields a truthy value.
user_query = st.chat_input(placeholder="Ask anything about the database")

if user_query:
    # Store the user's message in the history, then echo it.
    user_turn = {"role": "user", "content": user_query}
    st.session_state.messages.append(user_turn)
    st.chat_message("user").write(user_query)

    assistant_bubble = st.chat_message("assistant")
    with assistant_bubble:
        # The callback handler is given a container inside the assistant
        # bubble and is passed to the agent for this run.
        streamlit_callback_handler = StreamlitCallbackHandler(st.container())
        answer = agent.run(user_query, callbacks=[streamlit_callback_handler])

        # Store the reply in the history, then show it.
        assistant_turn = {"role": "assistant", "content": answer}
        st.session_state.messages.append(assistant_turn)
        st.write(answer)
