In [None]:
!pip install mysql-connector-python sqlalchemy matplotlib
import pandas as pd
import matplotlib.pyplot as plt
from sqlalchemy import create_engine
from db_setup import setup_database

# Set up the MySQL database
# setup_database() is imported from the project-local db_setup module. The
# engine it returns is kept at module level because run_query() below reads
# it as a global rather than taking it as a parameter.
engine = setup_database()

# Function to run a single query and return a DataFrame
def run_query(query, params=None):
    """Execute *query* on the module-level ``engine`` and return a DataFrame.

    Args:
        query: SQL text handed to ``pandas.read_sql``.
        params: Optional bind parameters forwarded unchanged to
            ``pandas.read_sql``.

    Returns:
        The DataFrame produced by ``pandas.read_sql``, or ``None`` when
        anything raised. Failures are printed, never propagated, so callers
        must check for ``None``.
    """
    try:
        # The connection is opened per call and released by the context manager.
        with engine.connect() as conn:
            frame = pd.read_sql(query, conn, params=params)
    except Exception as exc:
        print(f"Error executing query: {query}\nError: {str(exc)}")
        return None
    return frame

# Function to run dynamic region query
def run_dynamic_region_query():
    """Prompt for a region and return its monthly sales totals.

    The prompt repeats until the user enters one of the valid regions.
    Surrounding whitespace and letter case are ignored when matching (so
    ``"east"`` and ``" East "`` are both accepted), and the spelling listed
    in ``valid_regions`` is what gets bound into the query.

    Returns:
        The result of ``run_query`` for the chosen region: a DataFrame with
        ``month`` and ``total_sales`` columns, or ``None`` if the query
        failed.
    """
    valid_regions = ['Central', 'East', 'South', 'West']
    # Lower-cased lookup so matching is case-insensitive while the query
    # still receives the spelling from valid_regions.
    by_lowercase = {name.lower(): name for name in valid_regions}
    prompt = "Enter a region (Central, East, South, West) : "

    region = by_lowercase.get(input(prompt).strip().lower())
    while region is None:
        print("Invalid region! Please choose Central, East, South, or West")
        region = by_lowercase.get(input(prompt).strip().lower())

    # The region is passed as a bound parameter, never interpolated into SQL.
    query = """
    SELECT DATE_FORMAT(order_date, '%Y-%m') AS month, ROUND(SUM(sales), 2) AS total_sales
    FROM super_store_data
    WHERE region = %s
    GROUP BY month
    ORDER BY month;
    """
    df = run_query(query, params=(region,))
    return df

# Function to run multiple queries from a file
def RUN_QUERIES(filename):
    """Run every static query found in the SQL script *filename*.

    The script is split on ``;``. Blank fragments are ignored, and any
    statement containing a ``%s`` placeholder is treated as a dynamic query
    and skipped. Each remaining statement is executed through ``run_query``.

    Args:
        filename: Path of the SQL script to read.

    Returns:
        A list of DataFrames, one per statement that executed successfully,
        in script order. Skipped or failed statements leave no entry, so list
        positions need not match statement positions in the file.
    """
    with open(filename, 'r') as handle:
        sql_text = handle.read()

    frames = []
    for fragment in sql_text.split(';'):
        statement = fragment.strip()
        if not statement:
            # Nothing between two semicolons (or trailing whitespace)
            continue
        if '%s' in statement:
            # Dynamic query: needs parameters, so it is not run here
            continue
        frame = run_query(statement)
        if frame is None:
            # run_query already reported the failure
            continue
        frames.append(frame)
    return frames

# Function to plot the results of three queries (2, 5, and 8)
def _render_chart(filename, figsize, draw):
    """Create a figure, let *draw* populate it, save it, and always close it."""
    figure = plt.figure(figsize=figsize)
    try:
        draw()
        plt.tight_layout()
        plt.savefig(filename)
    finally:
        # Close on failure too; otherwise a half-drawn figure stays open
        # (and, in a notebook, may be rendered inline later).
        plt.close(figure)
    print(f"Saved {filename}")


def _plot_sales_by_category(frame):
    """Bar chart of the 'SUM(sales)' column per 'category'."""
    categories = frame['category']
    sales = frame['SUM(sales)']

    def draw():
        plt.bar(categories, sales, color='blue')
        plt.title('Total Sales by Category')
        plt.xlabel('Category')
        plt.ylabel('Sales ($)')

    _render_chart('sales_by_category.png', (8, 5), draw)


def _plot_top_states_by_profit(frame):
    """Horizontal bar chart of 'SUM(profit)' for the first five rows."""
    top_5 = frame.head(5)
    states = top_5['state']
    profits = top_5['SUM(profit)']

    def draw():
        plt.barh(states, profits, color='orange')
        plt.title('Top 5 States by Profit')
        plt.xlabel('Profit ($)')
        plt.ylabel('State')

    _render_chart('profit_by_state.png', (8, 5), draw)


def _plot_profit_margin_by_state(frame):
    """Pie chart of 'profit_margin' labelled by 'state'."""
    states = frame['state']
    margins = frame['profit_margin']

    def draw():
        plt.pie(margins, labels=states, autopct='%1.1f%%', colors=['blue', 'orange', 'green'])
        plt.title('Top 3 States by Profit Margin')

    _render_chart('profit_margin_by_state.png', (6, 6), draw)


def plot_results(results):
    """Save charts for the query results at list positions 1, 4 and 7.

    Charts are written to the current working directory as
    ``sales_by_category.png``, ``profit_by_state.png`` and
    ``profit_margin_by_state.png``. A chart is skipped when its position is
    missing from *results* or its DataFrame is empty. A failure while
    building one chart is printed and does not stop the others; the figure
    is closed either way.

    Args:
        results: List of DataFrames, selected purely by position. The frames
            must carry the column names the charts read ('category',
            'SUM(sales)', 'state', 'SUM(profit)', 'profit_margin').

    Returns:
        None.
    """
    # Check if results list is valid
    if not results:
        print("No results to plot")
        return

    # (list index, label used in error messages, plotting helper)
    charts = (
        (1, "Query 2", _plot_sales_by_category),
        (4, "Query 5", _plot_top_states_by_profit),
        (7, "Query 8", _plot_profit_margin_by_state),
    )
    for index, label, plot_chart in charts:
        if len(results) <= index or results[index].empty:
            continue
        try:
            plot_chart(results[index])
        except Exception as e:
            print(f"Error plotting {label}: {str(e)}")

# Run the queries
# RUN_QUERIES skips statements containing '%s' and drops any that fail, so
# `results` holds only the static queries that executed successfully.
results = RUN_QUERIES("sql_queries.sql")

# Run the dynamic region query
# Its result is appended last so it does not shift the positions (1, 4, 7)
# that plot_results() reads from the list.
dynamic_result = run_dynamic_region_query()
if dynamic_result is not None:
    results.append(dynamic_result)

# Print results
# Numbering follows the position in `results`, which may differ from the
# statement's position in the SQL file when statements were skipped or failed.
for i, df in enumerate(results):
    print(f"\nResult of query {i + 1}:")
    print(df)

# Plot the results
plot_results(results)

# NOTE: this message is printed unconditionally, even when plot_results()
# skipped a chart or reported a plotting error.
print("Analysis complete , Check sales_by_category.png , profit_by_state.png , and profit_margin_by_state.png in the project folder")