<a href="https://colab.research.google.com/github/bhandarytejas/Text-to-SQL-Project/blob/main/Text2SQL_Project_Main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install required packages
!pip install transformers datasets streamlit pandas plotly sqlparse

Collecting streamlit
  Downloading streamlit-1.50.0-py3-none-any.whl.metadata (9.5 kB)
Collecting pydeck<1,>=0.8.0b4 (from streamlit)
  Downloading pydeck-0.9.1-py2.py3-none-any.whl.metadata (4.1 kB)
Downloading streamlit-1.50.0-py3-none-any.whl (10.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.1/10.1 MB[0m [31m50.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pydeck-0.9.1-py2.py3-none-any.whl (6.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m60.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pydeck, streamlit
Successfully installed pydeck-0.9.1 streamlit-1.50.0


In [2]:
# Mount Google Drive for peristence
from google.colab import drive
drive.mount('/content/drive')

#Create project structure
!mkdir -p "/content/drive/MyDrive/text2sql_project"
!mkdir -p "/content/drive/MyDrive/text2sql_project/data"
!mkdir -p "/content/drive/MyDrive/text2sql_project/src"
!mkdir -p "/content/drive/MyDrive/text2sql_project/notebook"
!mkdir -p "/content/drive/MyDrive/text2sql_project/tests"

Mounted at /content/drive


In [3]:
!mkdir -p "/content/drive/MyDrive/text2sql_project/tests"

In [4]:
# !git config --global user.email "bhandarytejas92@gmail.com"
# !git config --global user.name "Tejas B"

# Navigate to your project directory
# %cd /content/drive/MyDrive/text2sql_project

# # Initialize git (only if you want version control in Colab)
# !git init

In [5]:
# Set Working Directory
import os
os.chdir('/content/drive/MyDrive/text2sql_project')

In [6]:
import pandas as pd
import sqlite3
import numpy as np
from datetime import datetime, timedelta
import plotly.express as px
import plotly.graph_objects as go

In [7]:
def create_retail_database():
  """Create a retail database for testing"""
  conn = sqlite3.connect('/content/drive/MyDrive/text2sql_project/data/retail_sample.db')

  #Create sample data
  np.random.seed(42)

  #Customer Table
  customer_df = pd.DataFrame({
      'customer_id': range(1, 1001),
      'name': [f'Customer_{i}' for i in range(1,1001)],
      'email': [f'customer{i}@gmail.com' for i in range(1,1001)],
      'city': np.random.choice(['New York', 'Los Angeles', 'Chicago', 'Houston', 'Phoneix'], 1000),
      'registration_date': pd.date_range('2022-01-01', periods=1000, freq='D')
  }
  )

  #Products Table
  products_df = pd.DataFrame({
      'product_id': range(1, 101),
      'name': [f'Product_{i}' for i in range(1,101)],
      'category': np.random.choice(['Electronics', 'Clothing', 'Home', 'Books', 'Sports'], 100),
      'price': np.random.uniform(10, 1000, 100)
  }
  )

  #Orders Table
  orders_data = []
  for order_id in range(1, 5001):
    customer_id = np.random.randint(1, 1001)
    order_date = pd.Timestamp('2023-01-01') + timedelta(days=np.random.randint(0, 365))
    orders_data.append({
        'order_id': order_id,
        'customer_id': customer_id,
        'order_date': order_date,
        'total_amount': np.random.uniform(20, 1000)
    })
  orders_df = pd.DataFrame(orders_data)

  # Save to database
  customer_df.to_sql('customers', conn, if_exists='replace', index=False)
  products_df.to_sql('products', conn, if_exists='replace', index=False)
  orders_df.to_sql('orders', conn, if_exists='replace', index=False)

  print("Database created Successfully!")
  print(f"Created {len(customer_df)} customer records")
  print(f"Created {len(products_df)} product records")
  print(f"Created {len(orders_df)} order records")

  conn.close()
  return "Database Ready!"


# Getting Retails Database
create_retail_database()

Database created Successfully!
Created 1000 customer records
Created 100 product records
Created 5000 order records


'Database Ready!'

In [8]:
db_path = '/content/drive/MyDrive/text2sql_project/data/retail_sample.db'
conn = sqlite3.connect(db_path)

tables_query = "SELECT name FROM sqlite_master WHERE type='table';"
tables = pd.read_sql(tables_query, conn)
print(tables)

# Check sample data
print("\n🔍 Sample customer data:")
sample_data = pd.read_sql("SELECT * FROM customers LIMIT 5", conn)
print(sample_data)

# Check database schema
print("\n📋 Customer table schema:")
schema = pd.read_sql("PRAGMA table_info(customers)", conn)
print(schema[['name', 'type']])

conn.close()

        name
0  customers
1   products
2     orders

🔍 Sample customer data:
   customer_id        name                email     city    registration_date
0            1  Customer_1  customer1@gmail.com  Houston  2022-01-01 00:00:00
1            2  Customer_2  customer2@gmail.com  Phoneix  2022-01-02 00:00:00
2            3  Customer_3  customer3@gmail.com  Chicago  2022-01-03 00:00:00
3            4  Customer_4  customer4@gmail.com  Phoneix  2022-01-04 00:00:00
4            5  Customer_5  customer5@gmail.com  Phoneix  2022-01-05 00:00:00

📋 Customer table schema:
                name       type
0        customer_id    INTEGER
1               name       TEXT
2              email       TEXT
3               city       TEXT
4  registration_date  TIMESTAMP


In [9]:
#Basic SQL Generator
def generate_sql (question):
  """Convert natural language question to SQL query"""
  q = question.lower()

  # Pattern 1: Count queries
  if "how many" in q or "count" in q:
    if "customer" in q:
      return "SELECT COUNT(*) as total_customers FROM customers;"
    elif "city" in q or "cities" in q:
      return "SELECT COUNT(DISTINCT city) as total_cities FROM customers;"

  # Pattern 2: Top/Best queries (most specific first)
  if ("top" in q or "best" in q) and ("city" in q or "cities" in q):
      return "SELECT city, COUNT(*) as customer_count FROM customers GROUP BY city ORDER BY customer_count DESC LIMIT 10;"

  if "top" in q or "best" in q:
      if "customer" in q:
        return "SELECT * FROM customers ORDER BY registration_date DESC LIMIT 10;"

  # Pattern 3: Unique/Distinct values
  if ("what" in q or "which" in q) and ("city" in q or "cities" in q):
      if "have" in q or "are" in q or "do we" in q:
        return "SELECT DISTINCT city, COUNT(*) as customer_count FROM customers GROUP BY city ORDER BY city;"

  # Pattern 4: List/Show queries
  if "list" in q or "show" in q:
      if "city" in q or "cities" in q:
        return "SELECT DISTINCT city FROM customers ORDER BY city;"
      if "customer" in q:
        return "SELECT * FROM customers LIMIT 20;"

  # Pattern 5: Recent queries
  if "recent" in q or "latest" in q:
        return "SELECT * FROM customers ORDER BY registration_date DESC LIMIT 10;"


  # Default fallback
  return "SELECT * FROM customers LIMIT 10;"

In [10]:
# STEP 2: SQL Execution Engine
def execute_sql_safely(sql_query, db_path, max_rows=100):
  """
  Execute SQL queries with safety checks and error handling
  """
  # Step 2a: Safety checks - prevent dangerous operations
  dangerous_keywords = ['DROP', 'DELETE', 'INSERT', 'UPDATE', 'ALTER', 'CREATE', 'TRUNCATE']
  sql_upper = sql_query.upper()

  for keyword in dangerous_keywords:
    if keyword in sql_upper:
      return {
          'success': False,
          'error': f"Dangerous operation '{keyword}' not allowed for safety reasons",
          'data': None
      }
  # Step 2b: Execute the query with error handling
  try:
    conn = sqlite3.connect(db_path)

    # Add LIMIT if not present (prevent huge result sets)
    if 'LIMIT' not in sql_upper:
      sql_query = sql_query.rstrip(';') + f" LIMIT {max_rows};"

    print("Query: ", sql_query)
    # Execute and fetch results
    result_df = pd.read_sql(sql_query, conn)
    conn.close()

    return {
        'success': True,
        'error': None,
        'data': result_df,
        'row_count': len(result_df)
    }

  except Exception as e:
    return {
        'success': False,
        'error': f"❌ SQL Error: {str(e)}",
        'data': None
    }

In [11]:
# # Test the execution engine
# db_path = '/content/drive/MyDrive/text2sql_project/data/retail_sample.db'
# print("🧪 Testing SQL Execution Engine:\n")
# print("="*60)

# # Test 1: Simple query
# print("\n📝 Test 1: Simple COUNT query")
# sql1 = "SELECT COUNT(*) as total FROM customers"
# result1 = execute_sql_safely(sql1, db_path)

# if result1['success']:
#     print(f"✅ Success! Found {result1['row_count']} row(s)")
#     print(result1['data'])

# else:
#     print(result1['error'])

# # Test 2: GROUP BY query
# print("\n📝 Test 2: GROUP BY query")
# sql2 = "SELECT city, COUNT(*) as customer_count FROM customers GROUP BY city ORDER BY customer_count DESC"
# result2 = execute_sql_safely(sql2, db_path)

# if result2['success']:
#     print(f"✅ Success! Found {result2['row_count']} row(s)")
#     print(result2['data'].head(10))
# else:
#     print(result2['error'])

# # Test 3: Safety check - dangerous query
# print("\n📝 Test 3: Safety check (should FAIL)")
# sql3 = "DROP TABLE customers"
# result3 = execute_sql_safely(sql3, db_path)

# if result3['success']:
#     print("⚠️ WARNING: Dangerous query was allowed!")
# else:
#     print(f"✅ Safety working: {result3['error']}")

# print("\n" + "="*60)

In [12]:
# COMPLETE SYSTEM: Question → SQL → Results
def ask_question(question, db_path):
  """
  Complete pipeline: Question to Answer
  """
  print(f"\n Questions: {question}")
  print("-"*60)

  #Generate SQL
  sql = generate_sql(question)
  print(f"🔧 Generated SQL:\n{sql}\n")

  result = execute_sql_safely(sql, db_path)

  if result['success']:
    print(f"✅ Query executed successfully!")
    print(f"📊 Found {result['row_count']} row(s)\n")
    print(result['data'])
    return result['data']
  else:
    print(result['error'])
    return None

In [15]:
# STEP 3: Automatic Visualization
import plotly.express as px
import plotly.graph_objects as go

def create_visualization(df, question):
    """
    Automatically create appropriate visualization based on data type
    """

    # Check if we have data
    if df is None or df.empty:
        print("📊 No data to visualize")
        return None

    # Get column information
    numeric_cols = df.select_dtypes(include=['int64', 'float64', 'number']).columns.tolist()
    text_cols = df.select_dtypes(include=['object']).columns.tolist()
    date_cols = df.select_dtypes(include=['datetime64']).columns.tolist()

    print(f"📊 Creating visualization...")

    # Case 1: Single number (like COUNT)
    if len(df) == 1 and len(numeric_cols) == 1:
        value = df[numeric_cols[0]].iloc[0]
        fig = go.Figure(go.Indicator(
            mode = "number",
            value = value,
            title = {"text": question},
            number = {'font': {'size': 60}}
        ))
        fig.show()
        return fig

    # Case 2: Category + Number (like city counts) - BAR CHART
    elif len(text_cols) >= 1 and len(numeric_cols) >= 1:
        x_col = text_cols[0]
        y_col = numeric_cols[0]

        # Limit to top 15 for readability
        if len(df) > 15:
            df_plot = df.head(15)
            title = f"Top 15: {question}"
        else:
            df_plot = df
            title = question

        fig = px.bar(
            df_plot,
            x=x_col,
            y=y_col,
            title=title,
            color=y_col,
            color_continuous_scale='viridis'
        )
        fig.update_layout(xaxis_tickangle=-45)
        fig.show()
        return fig

    # Case 3: Date/Time series - LINE CHART
    elif len(date_cols) >= 1 and len(numeric_cols) >= 1:
        fig = px.line(
            df,
            x=date_cols[0],
            y=numeric_cols[0],
            title=question,
            markers=True
        )
        fig.show()
        return fig

    # Case 4: Multiple numeric columns - show first two
    elif len(numeric_cols) >= 2:
        fig = px.scatter(
            df,
            x=numeric_cols[0],
            y=numeric_cols[1],
            title=question
        )
        fig.show()
        return fig

    # Case 5: Just text data - show as table (already displayed)
    else:
        print("📋 Data shown in table format above")
        return None


# Enhanced ask_question function with visualization
def ask_question_with_viz(question, db_path):
    """
    Complete pipeline: Question → SQL → Results → Visualization
    """
    print(f"\n💬 Question: {question}")
    print("-" * 60)

    # Step 1: Generate SQL
    sql = generate_sql(question)
    print(f"🔧 Generated SQL:\n{sql}\n")

    # Step 2: Execute SQL
    result = execute_sql_safely(sql, db_path)

    if result['success']:
        print(f"✅ Query executed successfully!")
        print(f"📊 Found {result['row_count']} row(s)\n")

        # Display results
        if result['row_count'] <= 10:
            print(result['data'].to_string(index=False))
        else:
            print(result['data'].head(10).to_string(index=False))
            print(f"\n... and {result['row_count'] - 10} more rows")

        # Step 3: Create visualization
        print()
        create_visualization(result['data'], question)

        return result['data']
    else:
        print(result['error'])
        return None



In [16]:
test_questions = [
    "How many customers do we have?",
    "Show me the top 10 cities by customer count",
    "What cities do we have customers in?",
]

for question in test_questions:
    result = ask_question_with_viz(question, db_path)
    print("\n" + "=" * 60)


💬 Question: How many customers do we have?
------------------------------------------------------------
🔧 Generated SQL:
SELECT COUNT(*) as total_customers FROM customers;

Query:  SELECT COUNT(*) as total_customers FROM customers LIMIT 100;
✅ Query executed successfully!
📊 Found 1 row(s)

 total_customers
            1000

📊 Creating visualization...




💬 Question: Show me the top 10 cities by customer count
------------------------------------------------------------
🔧 Generated SQL:
SELECT COUNT(*) as total_customers FROM customers;

Query:  SELECT COUNT(*) as total_customers FROM customers LIMIT 100;
✅ Query executed successfully!
📊 Found 1 row(s)

 total_customers
            1000

📊 Creating visualization...




💬 Question: What cities do we have customers in?
------------------------------------------------------------
🔧 Generated SQL:
SELECT DISTINCT city, COUNT(*) as customer_count FROM customers GROUP BY city ORDER BY city;

Query:  SELECT DISTINCT city, COUNT(*) as customer_count FROM customers GROUP BY city ORDER BY city LIMIT 100;
✅ Query executed successfully!
📊 Found 5 row(s)

       city  customer_count
    Chicago             190
    Houston             206
Los Angeles             190
   New York             210
    Phoneix             204

📊 Creating visualization...



