In [1]:
from pyspark.sql import SparkSession

# Use this function to get (or create) the Spark session
def spark_inst(master="local[*]", app_name="Spark"):
    """Return a SparkSession built with the given master and application name.

    Both arguments default to the values this helper previously hard-coded, so
    existing ``spark_inst()`` calls behave exactly as before.

    :param master: Master URL passed to ``SparkSession.builder.master``.
    :param app_name: Name passed to ``SparkSession.builder.appName``.
    :return: The session returned by ``getOrCreate()``.
    """
    return (
        SparkSession.builder
        .master(master)
        .appName(app_name)
        .getOrCreate()
    )

In [None]:
#extract.py
# Read data from mysql database
def extract(spark: "SparkSession", type: str, source: str):
    """Read a source into a Spark DataFrame.

    :param spark: Session whose ``read`` interface is used to load the data.
    :param type: Source storage type, matched case-insensitively:
                 "JDBC" reads a table from the MySQL ``world`` database,
                 "CSV" reads a file from the filesystem.
    :param source: For "JDBC", the table name; for "CSV", the path to load.
    :return: The loaded DataFrame.
    :raises ValueError: If ``type`` is neither "JDBC" nor "CSV"
                        (previously such calls silently returned ``None``).
    """
    import os

    kind = type.upper()

    # Read data from the MySQL database
    if kind == "JDBC":
        # Connection settings can be supplied through environment variables so that
        # credentials need not live in the notebook.
        # NOTE(review): the fallbacks reproduce the previously hard-coded values
        # (including the 'root' password); remove them once the environment is configured.
        return spark.read.format("JDBC").options(
            url=os.environ.get("MYSQL_URL", "jdbc:mysql://localhost/world"),
            dbtable=source,
            driver="com.mysql.cj.jdbc.Driver",
            user=os.environ.get("MYSQL_USER", "root"),
            password=os.environ.get("MYSQL_PASSWORD", "root"),
        ).load()

    # Read data from the filesystem
    if kind == "CSV":
        return spark.read.format("CSV").options(header=True, inferSchema=True).load(source)

    raise ValueError(f"Unsupported source type {type!r}; expected 'JDBC' or 'CSV'")

In [None]:
# constant.py
# Transform the columns by renaming them using a dictionary.
# Rename maps: original source column name -> new snake_case column name.
# Each table's country key is renamed to "country_code" so the frames share a join column,
# and clashing names ("Name", "Population") get a table-specific prefix so they stay
# distinct after joining.

# city table
CITY_COL_DICT = dict(
    ID="city_id",
    Name="city_name",
    CountryCode="country_code",
    District="city_district",
    Population="city_population",
)

# country table
COUNTRY_COL_DICT = dict(
    Code="country_code",
    Name="country_name",
    Continent="continent",
    Region="region",
    SurfaceArea="surface_area",
    IndepYear="independence_year",
    Population="country_population",
    LifeExpectancy="life_expectancy",
    GNP="gross_national_product",
    GNPOld="old_gross_national_product",
    LocalName="local_name",
    GovernmentForm="government_form",
    HeadOfState="head_of_state",
    Capital="capital",
    Code2="country_code_2",
)

# countrylanguage table
COUNTRY_LANGUAGE_COL_DICT = dict(
    CountryCode="country_code",
    Language="language",
    IsOfficial="is_official_language",
    Percentage="language_percentage",
)

In [None]:
# Join on country code, since it is common to all DataFrames
# constant.py

# Join key shared by the renamed city, country and countrylanguage frames.
JOIN_ON_COLUMNS = ["country_code"]
JOIN_TYPE = "left"

# Columns kept in the final output, grouped by the source table they come from
# (the shared "country_code" key appears once, at the front).
SPEC_COLS = (
    # from country
    [
        "country_code",
        "country_name",
        "region",
        "surface_area",
        "independence_year",
        "country_population",
        "life_expectancy",
        "local_name",
        "head_of_state",
        "capital",
        "country_code_2",
    ]
    # from city
    + ["city_id", "city_name", "city_district", "city_population"]
    # from countrylanguage
    + ["language", "is_official_language", "language_percentage"]
)

In [None]:
# load.py
# Load the data into MySQL and the filesystem

from pyspark.sql import DataFrame


def load(type: str, df: "DataFrame", target: str, mode: str = "overwrite"):
    '''
    Write a DataFrame to MySQL (via JDBC) or to the filesystem (as CSV).

    :param type: Output storage type (JDBC|CSV), matched case-insensitively.
                 "JDBC" stores the data in the MySQL ``world`` database,
                 "CSV" stores it on the filesystem.
    :param df: DataFrame to write.
    :param target: Output target
             - For filesystem - location where the data is written
             - For MySQL - table name
    :param mode: Save mode passed to ``df.write.mode``; defaults to "overwrite",
                 the value this function always used before.
    :raises ValueError: If ``type`` is neither "JDBC" nor "CSV"
                        (previously such calls silently wrote nothing).
    '''
    import os

    kind = type.upper()

    # Write data to the MySQL database under the given table name
    if kind == "JDBC":
        # Connection settings can be supplied through environment variables so that
        # credentials need not live in the notebook.
        # NOTE(review): the fallbacks reproduce the previously hard-coded values
        # (including the 'root' password); remove them once the environment is configured.
        df.write.format("JDBC").mode(mode).options(
            url=os.environ.get("MYSQL_URL", "jdbc:mysql://localhost/world"),
            dbtable=target,
            driver="com.mysql.cj.jdbc.Driver",
            user=os.environ.get("MYSQL_USER", "root"),
            password=os.environ.get("MYSQL_PASSWORD", "root"),
        ).save()
        print("Data successfully loaded to MySQL Database !!")

    # Write data to the filesystem
    elif kind == "CSV":
        # NOTE(review): Spark presumably writes `target` as a directory of part files even
        # though the name ends in ".csv" — confirm downstream consumers expect that.
        df.write.format("CSV").mode(mode).options(header=True).save(target)
        print("Data successfully loaded to filesystem !!")

    else:
        raise ValueError(f"Unsupported target type {type!r}; expected 'JDBC' or 'CSV'")

In [None]:
# Run all ETL steps by putting the functions together
# Import required libraries and modules
from pyspark.sql import SparkSession
from metadata.constant import CITY_COL_DICT, COUNTRY_LANGUAGE_COL_DICT, COUNTRY_COL_DICT, \
JOIN_TYPE,JOIN_ON_COLUMNS, SPEC_COLS, spark_inst
from extract import extract
from transform import rename_cols, join_df, specific_cols
from load import load

# Initiate the SparkSession
# NOTE(review): spark_inst is imported from metadata.constant above, but in this notebook it is
# defined in the first cell — confirm the module actually exports it.
SPARK = spark_inst()

#### Extract ####

# Extract CITY and COUNTRY data from MySQL
city_df = extract(SPARK, "JDBC", "city")
country_df = extract(SPARK, "JDBC", "country")

# Extract COUNTRYLANGUAGE data from the filesystem
country_language_df = extract(SPARK, "CSV", "filesystem/countrylanguage.csv")

#### Transformation ####

# 1. Rename columns
city_df = rename_cols(city_df, CITY_COL_DICT)
country_df = rename_cols(country_df, COUNTRY_COL_DICT)
country_language_df = rename_cols(country_language_df, COUNTRY_LANGUAGE_COL_DICT)

# 2. Join DataFrames on the common column "country_code"
# NOTE(review): both joins key only on country_code, so (assuming join_df is a plain join)
# each country row is repeated once per (city, language) pair of that country — confirm this
# output grain is intended.
country_city_df = join_df(country_df, city_df, JOIN_ON_COLUMNS, JOIN_TYPE)
country_city_language_df = join_df(country_city_df, country_language_df, JOIN_ON_COLUMNS, JOIN_TYPE)

# 3. Keep only the specified columns
country_city_language_df = specific_cols(country_city_language_df, SPEC_COLS)

# The same DataFrame is written twice below; cache it so the second write can reuse the
# computed result instead of re-running the JDBC/CSV reads and joins.
country_city_language_df = country_city_language_df.cache()

#### Load Data ####

# MySQL
load("JDBC", country_city_language_df, "CountryCityLanguage")

# FileSystem
load("CSV", country_city_language_df, "output/countrycitylanguage.csv")

# Both writes are done; release the cached data.
country_city_language_df.unpersist();