<a href="https://colab.research.google.com/github/bnehirartan/Developing-a-Chatbot-using-VS-Code-Colab-SQLite-Gemini-API-and-Gradio/blob/main/hw4_group6.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install gradio

In [None]:
import os, json
import google.generativeai as genai
import gradio as gr
import sqlite3
from google.generativeai.types import HarmCategory, HarmBlockThreshold
from google.colab import userdata
import traceback
from typing import List, Dict, Any

In [None]:
from google.colab import files
files.upload()  # Manually upload gemini_helper.py (written for hw2); the next cell imports it as `gh`

Saving gemini_helper.py to gemini_helper.py




In [None]:
import gemini_helper as gh

In [None]:
# Authenticate the google-generativeai SDK with the API key stored as the
# Colab secret named GEMINI_API_KEY.
genai.configure(api_key=userdata.get('GEMINI_API_KEY'))

In [None]:
# Absolute path of the Northwind SQLite file: <home directory>/Documents/northwind.db.
# NOTE(review): on Colab the home directory is usually /root and has no Documents
# folder by default, so the database must be copied there first — confirm.
DB_PATH = os.path.expanduser(os.path.join("~", "Documents", "northwind.db"))

In [None]:
# A model instance must exist at module level before a chat session can be
# started; the `model` created inside generate_sql_from_nl is local to that function.
model = genai.GenerativeModel(model_name="gemini-1.5-flash-latest")
chat = model.start_chat(history=[])

In [None]:
def get_db_connection(db_path=None):
    """
    Open a SQLite connection to the Northwind database.

    The file's existence is checked first because ``sqlite3.connect`` would
    otherwise silently create a new, empty database at a mistyped path.

    Parameters:
    db_path (str | None): Path to the database file. When omitted, the
        module-level ``DB_PATH`` is used, so existing callers are unaffected.

    Returns:
    sqlite3.Connection | None: A connection whose rows are ``sqlite3.Row``
        objects (columns accessible by name), or ``None`` when the file is
        missing or the connection cannot be opened. Errors are printed, not raised.
    """
    path = DB_PATH if db_path is None else db_path

    if not os.path.exists(path):
        print(f"Error: Database file not found! Please check the path: {path}")
        return None

    try:
        # timeout: seconds to wait for a lock held by another connection.
        conn = sqlite3.connect(path, timeout=10)
    except sqlite3.Error as e:
        print(f"Database connection error: {e}")
        return None

    conn.row_factory = sqlite3.Row  # enables column access by name
    return conn

In [None]:
def get_table_schema():
    """
    Describe every table in the database as a JSON string for the model prompt.

    Returns:
    str: A JSON object mapping each table name to
        ``{"columns": [{"name": ..., "type": ...}, ...], "primary_key": <name or null>}``.
        On failure a JSON object ``{"error": "<message>"}`` is returned instead,
        so the result can always be parsed with ``json.loads``.
    """
    conn = get_db_connection()
    if not conn:
        return json.dumps({"error": "Could not connect to database to extract schema."})

    try:
        cursor = conn.cursor()

        # get a list of all tables
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names = [row[0] for row in cursor.fetchall()]

        schema = {}
        for table_name in table_names:
            # PRAGMA arguments cannot be bound as parameters, so quote the
            # identifier (doubling embedded quotes) to support names with
            # spaces or punctuation such as "Order Details".
            quoted_name = '"' + table_name.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted_name});")
            columns = cursor.fetchall()

            # table_info rows are (cid, name, type, notnull, dflt_value, pk);
            # pk == 1 marks the first (or only) primary-key column.
            schema[table_name] = {
                "columns": [{"name": col[1], "type": col[2]} for col in columns],
                "primary_key": next((col[1] for col in columns if col[5] == 1), None),
            }

        return json.dumps(schema, indent=2)
    except Exception as e:
        print(f"Error getting schema: {e}")
        traceback.print_exc()
        return json.dumps({"error": str(e)})
    finally:
        conn.close()  # released on success and on failure alike

# Fetch the schema once at start-up; every request reuses this string.
print("Fetching database schema...")
DB_SCHEMA = get_table_schema()

# get_table_schema reports failure either as a JSON {"error": "<message>"}
# object or as a non-JSON message, so check before announcing success.
try:
    _schema_error = json.loads(DB_SCHEMA).get("error")
except (ValueError, AttributeError):
    _schema_error = DB_SCHEMA
if isinstance(_schema_error, str):
    print(f"Warning: could not fetch the database schema: {_schema_error}")
else:
    print("Schema fetched successfully!")


# define the JSON schema for structured output: the model must reply with an
# object holding the SQL to run and a short explanation of it.
json_schema = {
    "type": "object",
    "properties": {
        "sql_query": {
            "type": "string",
            "description": "The SQL query to execute on the database"
        },
        "explanation": {
            "type": "string",
            "description": "A brief explanation of what the SQL query does"
        }
    },
    "required": ["sql_query", "explanation"]
}

In [None]:
def execute_sql_query(sql_query: str) -> List[Dict[str, Any]]:
    """
    Run one model-generated SQL statement and return its rows.

    The connection is put into read-only mode (``PRAGMA query_only``) before
    the statement runs. The SQL is produced by an LLM from end-user text, so
    it must not be able to modify or drop data.

    Parameters:
    sql_query (str): A single SQLite statement.

    Returns:
    List[Dict[str, Any]]: One dict per row, keyed by column name. A statement
        that produces no result set yields ``[]``. On failure a one-element list
        ``[{"error": ..., ...}]`` is returned instead of raising.
    """
    conn = get_db_connection()
    if not conn:
        return [{"error": "Could not connect to database"}]

    try:
        # Any write (INSERT/UPDATE/DELETE/DROP/...) now fails with a readonly error.
        conn.execute("PRAGMA query_only = ON")
        cursor = conn.execute(sql_query)

        # description is None for statements that return no rows.
        if cursor.description is None:
            return []

        # Convert the results to a list of dictionaries
        columns = [col[0] for col in cursor.description]
        return [dict(zip(columns, row)) for row in cursor.fetchall()]
    except (sqlite3.Error, sqlite3.Warning) as e:
        # sqlite3.Warning is not an sqlite3.Error subclass; older Python versions
        # raise it for "You can only execute one statement at a time."
        return [{"error": f"SQL error: {e}", "query": sql_query}]
    finally:
        conn.close()

In [None]:
def calculate_gemini_cost(input_tokens, output_tokens):
    """
    Estimate the USD cost of one Gemini 1.5 Flash request.

    Pricing model (USD per 1M tokens). The tier is selected by the PROMPT
    length and applies to both directions: a prompt longer than 128,000 tokens
    bills its input AND its output at the higher rate.
    - Prompt <= 128,000 tokens: input $0.075, output $0.30
    - Prompt  > 128,000 tokens: input $0.15,  output $0.60

    Parameters:
    input_tokens (int): The number of prompt (input) tokens.
    output_tokens (int): The number of generated (output) tokens.

    Returns:
    float: The total cost in USD.
    """
    long_prompt = input_tokens > 128_000
    input_rate = 0.15 if long_prompt else 0.075   # USD per 1M input tokens
    output_rate = 0.60 if long_prompt else 0.30   # USD per 1M output tokens

    # calculate the input and output cost
    input_cost = (input_tokens / 1_000_000) * input_rate
    output_cost = (output_tokens / 1_000_000) * output_rate

    return input_cost + output_cost  # total cost

In [None]:
def _build_schema_prompt(schema_json: str) -> str:
    """Render the JSON produced by get_table_schema() as a compact 'table: columns' list."""
    schema = json.loads(schema_json)  # json.JSONDecodeError if the schema is not JSON
    if isinstance(schema.get("error"), str):
        raise ValueError(f"Database schema is unavailable: {schema['error']}")
    lines = ["Tables and their columns:"]
    for table, details in schema.items():
        columns = [col["name"] for col in details["columns"]]
        lines.append(f"- {table}: {', '.join(columns)}")
    return "\n".join(lines) + "\n"


def _build_system_prompt(schema_prompt: str) -> str:
    """Return the system instruction that restricts Gemini to Northwind SQL generation."""
    return f"""You are an expert SQL assistant that helps generate SQL queries for a Northwind company database. Your sole goal is to understand the user's request and generate a valid SQL query for the Northwind database.
The database contains the following tables and columns:
{schema_prompt}

**Do:**

* Generate valid SQLite SQL queries based on user requests
* Return exactly one read-only SELECT statement per request
* Use only tables and columns that exist in the schema
* Handle requests in both English and Turkish
* Ensure proper JOIN syntax when combining tables
* Keep SQL queries simple and efficient
* Always return responses in the required JSON format

**Do not:**

* Use tables or columns that don't exist in the schema
* Provide alternative solutions unless the user specifically asks
* Include unnecessary complexity in queries
* Respond with anything other than the required JSON format

**Example 1:**
**You:**  "Is there anything else I can help you with today?"
**Customer:** "Show me all customers from Germany"
**You:** {{"sql_query": "SELECT * FROM Customers WHERE Country = 'Germany';", "explanation": "This query selects all customer records where the Country field equals 'Germany'."}}

**Example 2:**
**You:**  "Is there anything else I can help you with today?"
**Customer:** "En pahalı 5 ürünü listele"
**You:** {{"sql_query": "SELECT ProductName, Price FROM Products ORDER BY Price DESC LIMIT 5;", "explanation": "Bu sorgu, ürünleri fiyatlarına göre azalan sırada sıralar ve en pahalı 5 ürünün adını ve fiyatını getirir."}}

**Example 3:**
**You:**  "Is there anything else I can help you with today?"
**Customer:** "Which employee handled the most orders?"
**You:** {{"sql_query": "SELECT e.EmployeeID, e.FirstName, e.LastName, COUNT(o.OrderID) as OrderCount FROM Employees e JOIN Orders o ON e.EmployeeID = o.EmployeeID GROUP BY e.EmployeeID ORDER BY OrderCount DESC LIMIT 1;", "explanation": "This query counts the orders handled by each employee and returns the employee who handled the most orders."}}

**Example 4:**
**You:**  "Is there anything else I can help you with today?"
**Customer:** "Hangi tedarikçi en çok ürün sağlıyor?"
**You:** {{"sql_query": "SELECT s.SupplierID, s.SupplierName, COUNT(p.ProductID) as ProductCount FROM Suppliers s JOIN Products p ON s.SupplierID = p.SupplierID GROUP BY s.SupplierID ORDER BY ProductCount DESC LIMIT 1;", "explanation": "Bu sorgu, her tedarikçinin sağladığı ürün sayısını hesaplar ve en çok ürün sağlayan tedarikçiyi döndürür."}}

**Note**
Continue the interaction according to the language used by the user.
*If the user writes in Turkish, respond with a Turkish explanation.
*If the user writes in English, respond with an English explanation.
"""


def _report_cost(user_query, response):
    """Print the estimated cost of one Gemini response, preferring the API's own token counts."""
    usage = getattr(response, "usage_metadata", None)
    if usage is not None:
        # Real billed counts; the prompt count includes the system instruction.
        input_tokens = usage.prompt_token_count
        output_tokens = usage.candidates_token_count
    else:
        # Fallback: whitespace word counts as a rough stand-in for tokens.
        input_tokens = len(user_query.split())
        output_tokens = len(response.text.split())
    cost = calculate_gemini_cost(input_tokens, output_tokens)
    print(f"Cost for the request: ${cost:.4f}")


def generate_sql_from_nl(user_query: str) -> Dict[str, Any]:
    """
    Translate a natural-language question (English or Turkish) into SQL with Gemini.

    Exactly one structured-output request is made. The model is configured with
    ``json_schema`` so its reply is a JSON object with ``sql_query`` and
    ``explanation`` fields.

    Parameters:
    user_query (str): The user's question.

    Returns:
    Dict[str, Any]: The parsed model reply. On any failure (schema unavailable,
        API error, unparsable reply) ``{"sql_query": "", "explanation": <message>}``
        is returned, so the caller shows the message instead of running SQL.
    """
    try:
        # simplified schema for the prompt to reduce token usage
        schema_prompt = _build_schema_prompt(DB_SCHEMA)

        generation_config = {
            "temperature": 0.1,
            "top_p": 0.95,
            "top_k": 64,
            "max_output_tokens": 8192,
            "response_mime_type": "application/json",
            "response_schema": json_schema,
        }

        model = genai.GenerativeModel(
            model_name="gemini-1.5-flash-latest",
            generation_config=generation_config,
            safety_settings={
                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
            },
            system_instruction=_build_system_prompt(schema_prompt),
        )

        print("Sending request to Gemini API...")
        # gemini_helper adds rate-limit handling / retries; it is called as
        # (callable, *args) and returns None when it gives up.
        # NOTE(review): signature inferred from its previous use here — confirm.
        response = gh.api_request_with_retry(model.generate_content, user_query)
        if response is None:
            return {
                "sql_query": "",
                "explanation": "❌ API Error: Could not get a response from Gemini."
            }
        print("Received response from Gemini API!")

        _report_cost(user_query, response)

        try:
            result = json.loads(response.text)
        except json.JSONDecodeError:
            result = None
        if not isinstance(result, dict):  # callers rely on dict.get
            print(f"Failed to parse JSON. Raw response: {response.text}")
            return {
                "sql_query": "",
                "explanation": "Failed to parse response from AI model. Please try again."
            }
        return result

    except Exception as e:
        print(f"Error in generate_sql_from_nl: {e}")
        traceback.print_exc()
        return {
            "sql_query": "",
            "explanation": f"Error communicating with AI model: {str(e)}. Please try again."
        }

In [None]:
def create_chat_interface():
    """
    Build the Gradio Blocks UI for the Northwind chatbot.

    Layout:
    - A chat panel with a question box and Send/Clear buttons.
    - A side column of example questions; clicking one copies its text into the box.

    Each question is answered by generate_sql_from_nl, then execute_sql_query. The
    explanation, SQL and JSON results are appended to the chat as Markdown.

    Returns:
    gr.Blocks: The interface (not yet launched).
    """
    welcome_message = "Hello! I'm here to answer your questions about the Northwind database. How can I help you?"
    example_questions = [
        "Almanya'daki müşteriler kimler?",
        "En pahalı 5 ürünü listele",
        "Hangi tedarikçi en çok ürün sağlıyor?",
        "Alice Mutton ürününün fiyatı nedir?",
    ]

    with gr.Blocks(css=".chatbot-container { max-width: 500px; margin: auto; } footer { visibility: hidden; }") as demo:
        gr.Markdown("# Northwind DB Chatbot")
        gr.Markdown("Ask questions about the company database in English or Turkish!")

        def example_query_click(query):
            """Return the clicked example's text (the button label) for the question box."""
            return query

        with gr.Row():
            with gr.Column(scale=3):
                # initialize chatbot with a welcome message
                chatbot = gr.Chatbot(value=[(None, welcome_message)], height=500)

                with gr.Row():
                    msg = gr.Textbox(placeholder="Type your question here... (e.g., 'What are the 5 most expensive products?')", label="Your Question")
                    send_btn = gr.Button("Send")

                clear = gr.Button("Clear")

            with gr.Column(scale=1):
                gr.Markdown("### Example Questions")
                # Buttons (not gr.HTML) so the examples are clickable; a Button used
                # as an input passes its label, which is copied into the textbox.
                for question in example_questions:
                    example_btn = gr.Button(question)
                    example_btn.click(example_query_click, inputs=example_btn, outputs=msg)

        def user(user_message, history):
            """Clear the textbox and append the question, with its answer pending (None)."""
            if not user_message or not user_message.strip():
                return "", history  # ignore blank submissions
            return "", history + [(user_message, None)]

        def bot(history):
            """Answer the pending question at the end of *history*; return the new history."""
            # Act only when the last turn is still waiting for an answer.
            if not history or history[-1][1] is not None:
                return history
            user_message = history[-1][0]
            history = history[:-1]

            try:
                # generate SQL from natural language
                response = generate_sql_from_nl(user_message)

                sql_query = response.get("sql_query", "")
                explanation = response.get("explanation", "")

                if not sql_query:
                    history.append((user_message, "I couldn’t generate a valid SQL query for this request. Please ask questions related to the Northwind DB."))
                    return history

                # execute the SQL query
                query_results = execute_sql_query(sql_query)

                if len(query_results) == 1 and "error" in query_results[0]:
                    error_message = query_results[0]["error"]
                    result = f"Error while executing query: {error_message}\n\nAttempted SQL: {sql_query}"
                else:
                    formatted_results = json.dumps(query_results, indent=2, ensure_ascii=False)

                    # Lines must start at column 0: Markdown treats 4+ leading
                    # spaces as an indented code block and ignores indented fences.
                    result = "\n".join([
                        "**What I understood:**",
                        explanation,
                        "**SQL Query:**",
                        "```sql",
                        sql_query,
                        "```",
                        "**Sonuçlar:**",
                        "```json",
                        formatted_results if query_results else "No results found.",
                        "```",
                        f"{len(query_results)} record(s) found.",
                    ])

                history.append((user_message, result))
            except Exception as e:
                history.append((user_message, f"An unexpected error occurred: {str(e)}"))
                traceback.print_exc()

            return history

        def clear_chat():
            """Clear the chat history."""
            return []

        # connect the interface components: echo the question immediately
        # (unqueued), then compute the answer in a follow-up step.
        send_btn.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, chatbot, chatbot
        )

        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, chatbot, chatbot
        )

        clear.click(clear_chat, None, chatbot)

    return demo

# Build the UI, then start the server. share=True requests a public share link;
# debug=True keeps this cell attached to the running server.
print("Launching interface...")
chat_interface = create_chat_interface()
chat_interface.launch(debug=True, share=True)