In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import matplotlib.patches as patches
import mysql.connector

# Database connection function
def connect_to_database(host, user, password, database):
    """Open a MySQL connection with the given credentials.

    Args:
        host: Server host name.
        user: Login user name.
        password: Login password.
        database: Schema to select on connect.

    Returns:
        The connection object from ``mysql.connector.connect``, or None when the
        driver raises ``mysql.connector.Error`` (the error is printed, not raised).
    """
    settings = dict(host=host, user=user, password=password, database=database)
    try:
        conn = mysql.connector.connect(**settings)
    except mysql.connector.Error as exc:
        print(f"Error: {exc}")
        return None
    print("Database connection successful.")
    return conn

# Fetch data from the database
def fetch_data_from_database(connection, query):
    """Run ``query`` on ``connection`` and return the rows as a DataFrame.

    Args:
        connection: Open database connection handed to ``pd.read_sql``.
        query: SQL text to execute.

    Returns:
        The result DataFrame, or None if ``pd.read_sql`` raises (the error is
        printed, not re-raised).
    """
    try:
        frame = pd.read_sql(query, connection)
    except Exception as exc:
        print(f"Error fetching data: {exc}")
        return None
    print("Data fetched successfully.")
    return frame

# Create a pivot table
def create_pivot_table(df):
    """Count cases for every (case status, gender) combination.

    Args:
        df: DataFrame with 'Case No', 'Case Status' and 'Gender' columns.

    Returns:
        DataFrame indexed by 'Case Status' with one column per 'Gender' value.
        Each cell is the number of non-null 'Case No' entries in that group, and
        combinations that never occur are filled with 0.
    """
    return df.pivot_table(
        index='Case Status',
        columns='Gender',
        values='Case No',  # counting this column counts the cases
        aggfunc='count',
        fill_value=0,
    )

# Plot data
def plot_data(pivot_table, output_path='combined_counts_and_pivot_table.pdf'):
    """Draw the case-count report (stacked bars + table) and save it as one file.

    The top panel is a stacked bar chart of case counts per case status, with one
    colour per gender column. The bottom panel shows the same counts as a table.
    Captions, a header, a footer and a border are added, then the figure is
    saved, shown and closed.

    Parameters
    ----------
    pivot_table : pandas.DataFrame
        Counts indexed by case status with one column per gender, as returned by
        ``create_pivot_table``.
    output_path : str, optional
        Destination of the combined figure. Defaults to the previously hard-coded
        ``'combined_counts_and_pivot_table.pdf'``.

    Returns
    -------
    None
        Nothing is drawn or saved when ``pivot_table`` is None or empty.
    """
    # With no data pandas raises "no numeric data to plot" and matplotlib has no
    # cells to build the table from, so stop before any figure is created.
    if pivot_table is None or pivot_table.empty:
        print('No data to plot.')
        return None

    # Two rows, one column: bar chart on top, table underneath
    fig, axs = plt.subplots(2, 1, figsize=(12, 12))
    bar_ax, table_ax = axs

    # Stacked bar plot, one colour per gender column
    pivot_table.plot(kind='bar', stacked=True, color=sns.color_palette("husl", len(pivot_table.columns)), ax=bar_ax)
    bar_ax.set_title('Count of Cases by Case Status and Gender', fontsize=15, weight='bold')
    bar_ax.set_xlabel('Case Status', fontsize=14, labelpad=10)
    bar_ax.set_ylabel('Number of Cases', fontsize=14, labelpad=10)
    bar_ax.set_xticklabels(pivot_table.index, rotation=45, ha='right')  # Rotate x-axis labels for better readability
    bar_ax.legend(title='Gender', fontsize=12)
    bar_ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Table panel: hide the axis and show the raw counts
    table_ax.axis('off')
    table_ax.table(cellText=pivot_table.values,
                   rowLabels=pivot_table.index,
                   colLabels=pivot_table.columns,
                   loc='center')
    table_ax.set_title('Pivot Table of Cases by Status and Gender', fontsize=16, fontweight='bold')

    # Captions, positioned in axes-fraction coordinates of the panel they describe
    comment_plot = "This plot shows the count of cases by status for each gender. Each color represents a different gender."
    plt.text(0.5, -0.75, comment_plot, ha='center', va='center', fontsize=12, transform=bar_ax.transAxes)

    # The table has a row-label column (case status) plus one column per gender
    # value, so the caption must not assume a fixed number of gender columns.
    comment_table = ("This table shows the count of cases by status for each gender. "
                     "The first column lists the case status; each remaining column is one gender.")
    plt.text(0.5, -0.09, comment_table, ha='center', va='center', fontsize=12, transform=table_ax.transAxes)

    # Footer and header in figure-fraction coordinates. The header sits above the
    # figure (y > 1) and is kept in the saved file only because savefig uses
    # bbox_inches='tight'.
    fig.text(0.5, 0.0, 'Report analysis by Mediport', ha='center', va='center', fontsize=13, fontweight='bold')
    fig.text(0.5, 1.1, 'Vagus Hospital', ha='center', va='center', fontsize=20, fontweight='bold')

    # Border in figure-fraction coordinates: a 0.03 margin on the left, right and
    # bottom, extended upward far enough to enclose the header at y=1.1.
    border = patches.Rectangle((-0.03, -0.03), 1.06, 1.17, transform=fig.transFigure, color='black', linewidth=3, fill=False)
    fig.add_artist(border)

    fig.tight_layout()

    # Save the combined figure, then display it
    fig.savefig(output_path, bbox_inches='tight', pad_inches=0.5)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

    print(f'Both the stacked bar plot and pivot table have been saved as {output_path}')

# Main function to execute the workflow
def main():
    """Run the report end to end: connect, fetch case rows, pivot them and plot.

    Connection settings come from the environment instead of being hard-coded in
    the notebook:

    - ``DB_HOST`` (default ``"localhost"``)
    - ``DB_USER`` (default ``"root"``)
    - ``DB_NAME`` (default ``"hospital_db"``)
    - ``DB_PASSWORD`` (prompted for with ``getpass`` when unset; set it for
      non-interactive runs)

    The connection is always closed, even if fetching or plotting raises.
    """
    import os
    from getpass import getpass

    # Database connection details
    host = os.environ.get("DB_HOST", "localhost")
    user = os.environ.get("DB_USER", "root")
    database = os.environ.get("DB_NAME", "hospital_db")
    password = os.environ.get("DB_PASSWORD")
    if password is None:
        # Prompt rather than storing the secret in the notebook source/outputs.
        password = getpass("MySQL password: ")

    # SQL query to fetch data (backticks quote column names that contain spaces)
    query = "SELECT `Case No`, `Case Status`, `Gender` FROM cases_table"

    # Connect to the database; connect_to_database returns None on failure
    connection = connect_to_database(host, user, password, database)
    if connection is None:
        return

    try:
        df = fetch_data_from_database(connection, query)
        if df is not None:
            pivot_table = create_pivot_table(df)
            plot_data(pivot_table)
    finally:
        # Close even if fetching or plotting raises, so the connection is not leaked.
        connection.close()

# Run the main function
if __name__ == "__main__":
    main()


Error: 2003 (HY000): Can't connect to MySQL server on 'localhost:3306' (10061)


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import matplotlib.patches as patches
import mysql.connector
from mysql.connector import Error

class DatabaseConnectionSingleton:
    """Holder for one shared MySQL connection, accessed only via class methods.

    The connection is stored in the class attribute ``_instance``. The class is
    never instantiated.
    """

    # Shared connection, or None when no connection is open or the last attempt
    # to open one failed.
    _instance = None

    @classmethod
    def get_instance(cls, host, user, password, database):
        """
        Returns a singleton instance of the database connection.
        Reinitializes the connection if it is closed or stale.

        NOTE: once a connection exists, the arguments are not consulted again.
        A caller passing different settings still receives the existing
        connection (unless its ping fails and it is rebuilt with the new ones).

        Returns:
            The shared connection, or None if (re)initialization failed.
        """
        if cls._instance is None:
            cls._instance = cls._initialize_connection(host, user, password, database)
        else:
            try:
                # Check the connection is still usable; reconnect=True lets the
                # driver try to re-establish a dropped connection first.
                cls._instance.ping(reconnect=True)
            except Error as e:
                print(f"Reinitializing database connection due to error: {e}")
                # Release the dead connection before replacing it so it is not leaked.
                cls._discard_instance()
                cls._instance = cls._initialize_connection(host, user, password, database)
        return cls._instance

    @classmethod
    def _initialize_connection(cls, host, user, password, database):
        """
        Initializes a new database connection.

        Returns:
            The new connection, or None if ``mysql.connector.connect`` raised
            ``Error`` (the error is printed, not re-raised).
        """
        try:
            connection = mysql.connector.connect(
                host=host,
                user=user,
                password=password,
                database=database
            )
            print("Database connection initialized.")
            return connection
        except Error as e:
            print(f"Error initializing database connection: {e}")
            return None

    @classmethod
    def _discard_instance(cls):
        """Best-effort close of a connection already known to be broken, then forget it."""
        try:
            cls._instance.close()
        except Error:
            # The connection already failed its health check; a failure while
            # closing it carries no extra information, so it is ignored here.
            pass
        cls._instance = None

    @classmethod
    def close_instance(cls):
        """
        Closes the database connection if it exists.

        The stored instance is reset to None even when closing raises ``Error``.
        """
        if cls._instance is not None:
            try:
                cls._instance.close()
                print("Database connection closed.")
            except Error as e:
                print(f"Error closing database connection: {e}")
            finally:
                cls._instance = None

class DataProcessor:
    """Fetches case rows over a database connection, pivots them and renders the report."""

    def __init__(self, connection):
        """
        Args:
            connection: Open database connection, passed unchanged to ``pd.read_sql``.
        """
        self.connection = connection

    def fetch_data(self, query):
        """
        Fetches data from the database using the provided query.

        Returns:
            The result as a DataFrame, or None if ``pd.read_sql`` raised (the
            error is printed, not re-raised).
        """
        try:
            df = pd.read_sql(query, self.connection)
            print("Data fetched successfully.")
            return df
        except Exception as e:
            print(f"Error fetching data: {e}")
            return None

    def create_pivot_table(self, df):
        """
        Creates a pivot table from the DataFrame.

        Args:
            df: DataFrame with 'Case No', 'Case Status' and 'Gender' columns.

        Returns:
            DataFrame indexed by 'Case Status' with one column per 'Gender' value.
            Each cell counts the non-null 'Case No' entries in that group;
            missing combinations are filled with 0.
        """
        pivot_table = pd.pivot_table(
            df,
            values='Case No',
            index='Case Status',
            columns='Gender',
            aggfunc='count',
            fill_value=0
        )
        return pivot_table

    def plot_data(self, pivot_table, output_path='combined_counts_and_pivot_table.pdf'):
        """
        Plots the data using the pivot table.

        Draws a stacked bar chart (top) and a table of the same counts (bottom),
        adds a header, footer and border, saves the figure to ``output_path``,
        shows it and then closes it.

        Args:
            pivot_table: DataFrame of counts indexed by case status with one
                column per gender (as built by ``create_pivot_table``).
            output_path: Destination file for the combined figure. Defaults to
                the previously hard-coded ``combined_counts_and_pivot_table.pdf``.

        Returns:
            None. Nothing is drawn or saved when ``pivot_table`` is None or empty.
        """
        # With no data pandas raises "no numeric data to plot" and matplotlib has
        # no cells for the table, so stop before any figure is created.
        if pivot_table is None or pivot_table.empty:
            print('No data to plot.')
            return None

        fig, axs = plt.subplots(2, 1, figsize=(12, 12))
        bar_ax, table_ax = axs

        # Stacked bar plot, one colour per gender column
        pivot_table.plot(kind='bar', stacked=True, color=sns.color_palette("husl", len(pivot_table.columns)), ax=bar_ax)
        bar_ax.set_title('Count of Cases by Case Status and Gender', fontsize=15, weight='bold')
        bar_ax.set_xlabel('Case Status', fontsize=14, labelpad=10)
        bar_ax.set_ylabel('Number of Cases', fontsize=14, labelpad=10)
        bar_ax.set_xticklabels(pivot_table.index, rotation=45, ha='right')
        bar_ax.legend(title='Gender', fontsize=12)
        bar_ax.grid(axis='y', linestyle='--', alpha=0.7)

        # Table of the same counts; the axis itself is hidden
        table_ax.axis('off')
        table_ax.table(cellText=pivot_table.values,
                       rowLabels=pivot_table.index,
                       colLabels=pivot_table.columns,
                       loc='center')
        table_ax.set_title('Pivot Table of Cases by Status and Gender', fontsize=16, fontweight='bold')

        # Footer and header in figure-fraction coordinates. The header sits above
        # the figure (y > 1) and is kept in the saved file only because savefig
        # uses bbox_inches='tight'.
        fig.text(0.5, 0.0, 'Report analysis by Mediport', ha='center', va='center', fontsize=13, fontweight='bold')
        fig.text(0.5, 1.1, 'Vagus Hospital', ha='center', va='center', fontsize=20, fontweight='bold')

        # Border in figure-fraction coordinates: a 0.03 margin left/right/bottom,
        # extended upward far enough to enclose the header at y=1.1.
        border = patches.Rectangle((-0.03, -0.03), 1.06, 1.17, transform=fig.transFigure, color='black', linewidth=3, fill=False)
        fig.add_artist(border)

        fig.tight_layout()
        fig.savefig(output_path, bbox_inches='tight', pad_inches=0.5)
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
        print(f'Both the stacked bar plot and pivot table have been saved as {output_path}')

# Main function
def main():
    """Run the report end to end using the shared connection and DataProcessor.

    Connection settings come from the environment instead of being hard-coded in
    the notebook:

    - ``DB_HOST`` (default ``"localhost"``)
    - ``DB_USER`` (default ``"root"``)
    - ``DB_NAME`` (default ``"hospital_db"``)
    - ``DB_PASSWORD`` (prompted for with ``getpass`` when unset; set it for
      non-interactive runs)

    The shared connection is always closed, even if processing raises.
    """
    import os
    from getpass import getpass

    # Database connection details
    host = os.environ.get("DB_HOST", "localhost")
    user = os.environ.get("DB_USER", "root")
    database = os.environ.get("DB_NAME", "hospital_db")
    password = os.environ.get("DB_PASSWORD")
    if password is None:
        # Prompt rather than storing the secret in the notebook source/outputs.
        password = getpass("MySQL password: ")

    # SQL query (backticks quote column names that contain spaces)
    query = "SELECT `Case No`, `Case Status`, `Gender` FROM cases_table"

    # Get database connection; get_instance returns None when connecting failed
    connection = DatabaseConnectionSingleton.get_instance(host, user, password, database)
    if connection is None:
        return

    try:
        processor = DataProcessor(connection)
        df = processor.fetch_data(query)
        if df is not None:
            pivot_table = processor.create_pivot_table(df)
            processor.plot_data(pivot_table)
    finally:
        # Close even if fetching or plotting raises, so the connection is not leaked.
        DatabaseConnectionSingleton.close_instance()

# Run the main function
if __name__ == "__main__":
    main()
