# Import libraries & create spark session

In [None]:
# Import libraries
from pymysql import connect
import pandas as pd
import os
import warnings
warnings.filterwarnings('ignore')
from getpass import getpass
import re
import findspark
findspark.init()
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql.functions import col, count

# Create SparkSession
## "spark.driver.memory" must be sized to the user's needs and hardware.
## "spark.driver.extraClassPath" must point at the user's MySQL Connector/J jar.
## The jar path is a raw string: in a normal literal the Windows backslashes
## ("\P", "\M", "\C", "\m") are invalid escape sequences, which Python 3.12+
## reports with a SyntaxWarning. The resulting path value is unchanged.
MYSQL_CONNECTOR_JAR = r"C:\Program Files (x86)\MySQL\Connector J 8.0\mysql-connector-j-8.0.33.jar"

spark = (
    SparkSession.builder
    .config("spark.driver.memory", "28g")
    .config("spark.driver.extraClassPath", MYSQL_CONNECTOR_JAR)
    # LEGACY: rebase INT96 timestamps to the legacy hybrid (Julian + Gregorian)
    # calendar when writing Parquet files.
    .config("spark.sql.parquet.int96RebaseModeInWrite", "LEGACY")
    .appName("sparketl")  # the application name can be changed freely
    .getOrCreate()
)

display(spark)

# Connect to database

In [None]:
# Connect to the MySQL server (no default database yet) and list its databases.
# Credentials are read with getpass so they are not echoed in the notebook.
IP = getpass("Please enter database IP: ")
User = getpass("Please enter database User: ")
Password = getpass("Please enter database Password: ")
try:
    server_conn = connect(host=IP, user=User, password=Password)
    try:
        cur = server_conn.cursor()
        print("Successfully connect to database!", '\n')

        # List every database visible to this user on the MySQL server
        cur.execute('show databases')
        print("Showing all databases available...", '\n')

        # NOTE: the next cell iterates `db` as these fetched rows (one 1-tuple
        # per database name), so the result is kept under that name.
        db = cur.fetchall()
    finally:
        # This server-level connection is only needed for the listing; the next
        # cell opens its own connection bound to the chosen database. Closing it
        # here fixes the leak caused by `db` previously overwriting the handle.
        server_conn.close()

    for data in db:
        print(''.join(data))
    print('\n', "Total databases in the connection: ", len(db))
except Exception as e:
    # Best-effort notebook cell: report the failure instead of raising.
    print(str(e))

## Fetch all tables in the selected database

In [None]:
## Collect the database names fetched by the previous cell (rows are 1-tuples)
db_data = [''.join(row) for row in db]

# Keep asking until the user names one of the listed databases
userin = input("Please select the database to be used: ")
while userin not in db_data:
    userin = input("Please enter the database to be used correctly: ")

# Reconnect, this time bound to the chosen database; `db` and `cur` are used
# by later cells.
db = connect(host=IP, user=User, passwd=Password, database=userin)
cur = db.cursor()

# SHOW TABLES yields one single-column row per table
query = "SHOW TABLES"
cur.execute(query)
tables = cur.fetchall()
print("Database", userin, "selected! The tables available in the selected database are: ", '\n')

for table_name in (''.join(row) for row in tables):
    print(table_name)
print('\n')
print("Total tables in the database: ", len(tables))

In [None]:
# Ask whether to load every non-empty table into Spark DataFrames, or to
# browse individual tables on demand.
def _read_table(name):
    """Return a Spark DataFrame over table ``name`` of the selected database.

    Reads through the MySQL JDBC driver using the module-level connection
    settings ``IP``, ``userin``, ``User`` and ``Password``. Schema and table
    names are quoted with backticks (MySQL identifier quoting).
    """
    return (
        spark.read.format("jdbc")
        .option("url", f"jdbc:mysql://{IP}/{userin}")
        .option("user", User)
        .option("password", Password)
        .option("dbtable", f"`{userin}`.`{name}`")
        .load()
    )


choice = input("Do you wish to get all tables transformed into data frames? (y/n): ")
while choice.lower() not in ['y', 'n', 'yes', 'no']:
    print('\n')
    choice = input("Please enter your choice correctly! (y/n): ")
    print('\n')

if choice.lower() in ['n', 'no']:
    # Plain table-name strings; the join cell validates user input against this.
    table2 = [tb[0] for tb in tables]

    choice2 = input("Do you want to continue to explore any table or end the query? (y/n): ")
    while choice2.lower() not in ['y', 'n', 'yes', 'no']:
        choice2 = input("Please enter your choice correctly! (y/n): ")

    if choice2.lower() in ['n', 'no']:
        # The user asked to end the query, so do not prompt to continue it.
        print("Query end!")
    else:
        choice3 = input("Which table do you want to query for: ")
        while choice3 not in table2:
            choice3 = input("Table not found! Kindly enter your choice again: ")

        df = _read_table(choice3)
        df.show()
        print('\n')

        choice4 = input("Do you want to continue querying and displaying tables? (y/n): ")
        while choice4.lower() not in ['y', 'n', 'yes', 'no']:
            choice4 = input("Please enter your choice correctly! (y/n): ")

        if choice4.lower() in ['n', 'no']:
            print('\n')
            print("Query end!")
        else:
            while True:
                choice5 = input("Please enter the table you want to query for, else press n/no: ")
                # n/no is honoured on retries too, so a mistyped name cannot
                # trap the user in the "Table not found!" loop.
                while choice5.lower() not in ['n', 'no'] and choice5 not in table2:
                    choice5 = input("Table not found! Kindly enter your choice again: ")

                if choice5.lower() in ['n', 'no']:
                    print('Query end!')
                    break

                df = _read_table(choice5)
                df.show()

else:
    print('\n')
    print("Extracting those tables with data inside...")
    print('\n')

    # Export the non-empty tables into Spark DataFrames. Each row count is
    # remembered so a table is scanned once here, not again for the summary.
    dfs = {}
    row_counts = {}
    for table in tables:
        table = ''.join(table)
        df = _read_table(table)
        n_rows = df.count()
        if n_rows > 0:
            dfs[table] = df
            row_counts[table] = n_rows

    # Summarise the DataFrames
    for table, df in dfs.items():
        print("Table: {}".format(table))
        print("Shape: Rows - {}, Columns - {}".format(row_counts[table], len(df.columns)))
        print("-----------------------------", '\n')

    print("Total Spark DataFrames with data inside: ", len(dfs))

# View tables in the selected database

In [None]:
# Let the user display any of the loaded tables, one table per query
if choice.lower() in ['y', 'yes']:
    choice6 = input("Do you want to query any table? (y/n) ")
    while choice6.lower() not in ['y', 'n', 'yes', 'no']:
        choice6 = input("Please enter your choice correctly! (y/n): ")

    if choice6.lower() in ['n', 'no']:
        print("Table query session end!")
    else:
        keys = list(dfs.keys())
        # The first round uses slightly different prompts than later rounds.
        first_round = True
        while True:
            que = input("Which table to be displayed (One table per query): ")
            while que not in keys:
                if first_round:
                    print(que, " not valid!", '\n')
                else:
                    print(que, " not valid!")
                que = input("Please enter table name to be displayed: ")

            print("Table ", que, ": ", '\n')
            selected = dfs[que]
            # Pass the row count so every row is shown, not just Spark's default 20
            selected.show(selected.count())

            if first_round:
                choice7 = input("Do you want to query again? (y/n) ")
            else:
                choice7 = input("Do you want to continue display table: (y/n) ")
            while choice7.lower() not in ['y', 'n', 'yes', 'no']:
                choice7 = input("Invalid response! Kindly enter again: ")

            if choice7.lower() in ['n', 'no']:
                print("Query end! Proceed to next step...")
                break
            first_round = False

In [None]:
# Define function to query the foreign-key relationships of a table
def tb_colq(choice8):
    """Print (and return) the foreign-key relationships involving a table.

    Relies on module-level state from earlier cells: ``tables`` (rows from
    SHOW TABLES), ``cur`` (a cursor on the selected database) and ``userin``
    (the selected schema name).

    Args:
        choice8: Table name to inspect. It is re-prompted until it matches one
            of the names in ``tables``.

    Returns:
        A pandas DataFrame with one row per key column. It covers both
        directions: foreign keys in other tables that reference ``choice8``,
        and foreign keys declared on ``choice8`` itself. The frame is also
        printed.
    """
    table_names = {''.join(row) for row in tables}
    while choice8 not in table_names:
        print(choice8, " is invalid!")
        choice8 = input("Please enter again: ")

    # Parameterized query: schema/table names are passed as values, never
    # interpolated into the SQL text. The second clause adds outgoing foreign
    # keys; rows without REFERENCED_TABLE_NAME (plain PK/unique keys) are skipped.
    query = (
        "SELECT TABLE_NAME, COLUMN_NAME, CONSTRAINT_NAME, "
        "REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME "
        "FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE "
        "WHERE (REFERENCED_TABLE_SCHEMA = %s AND REFERENCED_TABLE_NAME = %s) "
        "OR (TABLE_SCHEMA = %s AND TABLE_NAME = %s "
        "AND REFERENCED_TABLE_NAME IS NOT NULL);"
    )
    cur.execute(query, (userin, choice8, userin, choice8))
    results = cur.fetchall()

    tb_rel = pd.DataFrame(
        results,
        columns=['TABLE_NAME', 'COLUMN_NAME', 'CONSTRAINT_NAME',
                 'REFERENCED_TABLE_NAME', 'REFERENCED_COLUMN_NAME'],
    )
    print(tb_rel)
    return tb_rel

# Optionally inspect foreign-key relationships, one table at a time
choice9 = input("Do you want to query for relationship between tables? (y/n) ")
while choice9.lower() not in ['y', 'n', 'yes', 'no']:
    choice9 = input("Please enter your choice correctly! (y/n): ")

if choice9.lower() in ['n', 'no']:
    print("Query end! Proceed to next step...")
else:
    first_pass = True
    while True:
        # Prompt the user for input
        choice8 = input("Please enter the related table for the query: ")
        print('\n')
        tb_colq(choice8)

        answer = input("Do you want to continue checking for other table? (y/n) ")
        while answer.lower() not in ['y', 'n', 'yes', 'no']:
            answer = input("Please enter your choice correctly! (y/n): ")
        # Keep the same module-level answer names the notebook used before
        if first_pass:
            choice10 = answer
        else:
            choice11 = answer

        if answer.lower() in ['n', 'no']:
            # Only the first "stop" is preceded by a blank line
            if first_pass:
                print('\n')
            print("Query end! Proceed to next step...")
            break
        first_pass = False

# Join tables in database

In [None]:
# Define join df function
# Join methods offered to the user (lower-case menu spellings).
_JOIN_CHOICES = ['left', 'right', 'full', 'inner', 'cross',
                 'left anti', 'left semi', 'right anti', 'right semi']


def _lookup_frame(name):
    """Return the Spark DataFrame registered under ``name``.

    Looks first in the module-level ``dfs`` (tables bulk-loaded earlier; that
    name only exists when the user chose to load all tables), then in
    ``dfs_joined`` (results of previous joins). Any other name is read from
    the selected MySQL database over JDBC using ``IP``/``userin``/``User``/
    ``Password``.
    """
    try:
        loaded = dfs
    except NameError:
        loaded = {}
    if name in loaded:
        return loaded[name]
    if name in dfs_joined:
        return dfs_joined[name]
    return (
        spark.read.format("jdbc")
        .option("url", f"jdbc:mysql://{IP}/{userin}")
        .option("user", User)
        .option("password", Password)
        .option("dbtable", f"`{userin}`.`{name}`")
        .load()
    )


def joindf(intbsplit):
    """Interactively chain-join the named DataFrames and store the result.

    The i-th frame is aliased ``df<i+1>``. For each adjacent pair the user
    picks a join method and, except for cross joins, the key column on each
    side. The result is stored in the module-level ``dfs_joined`` under a
    user-supplied name, then printed, shown and summarised.

    Args:
        intbsplit: Names of loaded tables or earlier joins, in join order
            (callers validate the names and pass at least two).

    Returns:
        The joined Spark DataFrame.
    """
    frames = [_lookup_frame(name) for name in intbsplit]
    alias = ['df' + str(i + 1) for i in range(len(frames))]
    joined_df = frames[0].alias(alias[0])

    for i in range(len(frames) - 1):
        tb1 = alias[i]
        tb2 = alias[i + 1]
        right_df = frames[i + 1].alias(tb2)

        join_type = input(f"Please select a merge method for joining {tb1} and {tb2}: (left/right/inner/cross/full/left anti/left semi/right anti/right semi) ")
        while join_type.lower() not in _JOIN_CHOICES:
            join_type = input("Kindly re-enter your decision: (left/right/inner/cross/full/left anti/left semi/right anti/right semi) ")
        join_type = join_type.lower()

        if join_type == 'cross':
            # A cross join is the Cartesian product and takes no key columns;
            # passing an equality condition with how='cross' acts as an inner join.
            joined_df = joined_df.crossJoin(right_df)
            continue

        col1 = input(f"Which column to join from {tb1}: ")
        col2 = input(f"Which column to join from {tb2}: ")
        condition = col(f"{tb1}.{col1}") == col(f"{tb2}.{col2}")

        if join_type.startswith('right '):
            # Spark has no right semi/anti join type ('right_semi' is rejected),
            # so run the mirrored left semi/anti join with the operands swapped.
            mirrored = join_type.replace('right ', 'left_')
            joined_df = right_df.join(joined_df, condition, how=mirrored)
        else:
            # 'left anti' -> 'left_anti', 'left semi' -> 'left_semi'; others as-is
            joined_df = joined_df.join(right_df, condition, how=join_type.replace(' ', '_'))

    nmjndf = input("Please enter a name for the joined table: ")
    dfs_joined[nmjndf] = joined_df
    print("Joined Data Frame:", nmjndf, '\n')
    joined_df.show()
    print("Joined dataframe contains ", joined_df.count(), "rows, and ", len(joined_df.columns), "columns", '\n')
    return joined_df

def _ask_tables_to_join(valid_names):
    """Prompt for at least two comma-separated names and re-prompt invalid ones.

    A name is accepted if it is in ``valid_names`` or is a previously joined
    frame in the module-level ``dfs_joined``. Returns the list of names.
    """
    userintb = input("Which tables need to be joined? (Enter at least two tables separated by commas): ")
    intbsplit = re.split(r',\s*', userintb)
    while len(intbsplit) < 2:
        userintb = input("Please enter the tables that need to be joined? (Enter at least two tables separated by commas): ")
        intbsplit = re.split(r',\s*', userintb)
    for i, table in enumerate(intbsplit):
        while table not in valid_names and table not in dfs_joined:
            print(table, "is not in the schema! Kindly enter a valid table name: ")
            table = input()
            intbsplit[i] = table
    return intbsplit


# Request user whether want to merge tables
choice12 = input("Do you want to join any tables in the database and convert it into a pandas DataFrame? (y/n): ")
while choice12.lower() not in ['y', 'n', 'yes', 'no']:
    print('\n')
    choice12 = input("Please enter your choice correctly! (y/n): ")
    print('\n')

if choice12.lower() in ['n', 'no']:
    print('\n')
    print("Session End!")
else:
    dfs_joined = {}
    all_loaded = choice.lower() in ['y', 'yes']
    if all_loaded:
        # Only the non-empty tables loaded earlier can be joined.
        valid_tables = set(dfs)
    else:
        # Tables were not bulk-loaded: any table in the schema may be joined.
        # Selected tables are read on demand below so that joindf finds them
        # in `dfs` (previously this path crashed because `dfs` was undefined).
        valid_tables = set(table2)
        try:
            dfs
        except NameError:
            dfs = {}

    first_round = True
    while True:
        intbsplit = _ask_tables_to_join(valid_tables)
        for name in intbsplit:
            if name not in dfs and name not in dfs_joined:
                dfs[name] = (
                    spark.read.format("jdbc")
                    .option("url", f"jdbc:mysql://{IP}/{userin}")
                    .option("user", User)
                    .option("password", Password)
                    .option("dbtable", f"`{userin}`.`{name}`")
                    .load()
                )
        joindf(intbsplit)

        # Request user whether there are any other tables to join
        if first_round and all_loaded:
            more_prompt = "Are there any other tables in the database to be joined? (y/n) "
        else:
            more_prompt = "Are there any other tables to be joined? (y/n) "
        first_round = False

        more = input(more_prompt)
        while more.lower() not in ['y', 'n', 'yes', 'no']:
            more = input("Please enter your choice correctly! (y/n): ")
        if more.lower() in ['n', 'no']:
            print("Session end!")
            break