In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import to_date, date_format, trim, concat, col, lit, datediff, when, regexp_replace, regexp_extract
import psycopg2
import urllib.parse

In [2]:
spark = SparkSession.Builder().appName("ETL_Assignment").master("local[*]").config("spark.jars", "postgresql-42.7.4.jar").getOrCreate()

24/11/02 13:09:40 WARN Utils: Your hostname, Nikhils-MacBook-Air.local resolves to a loopback address: 127.0.0.1; using 192.168.0.101 instead (on interface en0)
24/11/02 13:09:40 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
24/11/02 13:09:40 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [3]:
def read_csv(filePath):
    """
    This function will read data from the arg filePath (which should be a string) and will return the data as a PySpark data frame.
    Since I don't have expertise working with Padas, I have used PySpark.
    I can do the same using Pandas, but I will need the help of the GenAI tool, and at this moment, I don't want to do this.
    """
    try:
        employee_df = spark.read.option("header", True).csv(filePath)
        return employee_df
    except:
        print("Kuch to gadbad hai Daya, Data read nahi ho raha hai")
        return None

In [43]:
filePath = "employee_details.csv"
read = read_csv(filePath)
read.show()

                                                                                

+----------+---------+--------------+----------+-----------+-------+
|EmployeeID|FirstName|    LastName  | BirthDate| Department| Salary|
+----------+---------+--------------+----------+-----------+-------+
|  E001    |  Alice  |     White    |1990-06-12|  Finance  |  55000|
|  E002    |    Bob  |     Brown    |1988-01-03| IT        |  90000|
|  E003    |   Carol |     Grey     |1995-07-15| HR        |  47000|
|  E004    |   David |     Black    |1992-09-23| Marketing |  75000|
|  E005    |   Eve   |     Green    |1985-03-18| IT        | 125000|
|  E006    |   Frank |     Turner   |1980-11-21| Finance   |  85000|
|  E007    |   Grace |     Harper   |1993-02-15| HR        |  62000|
|  E008    |    Holly|     Norman   |1999-04-09| Marketing |  43000|
|  E009    |     Ian |     Ross     |1996-05-27| Finance   | 110000|
|  E010    |   John  |     Bishop   |1997-08-18| IT        |  70000|
|EmployeeID|FirstName|    LastName  | BirthDate| Department| Salary|
|  E011    |    Ada  |    L Lovela

In [44]:
def transform_data(df):
    """
    This function transforms the input DataFrame by performing various cleaning and transformation operations as per requirement.

    This function will require a DF as an argument and will return a df.
    """

    # Convert BirthDate to Date type and filter out nulls
    df = (df.withColumn("BirthDate", to_date(col("BirthDate"), "yyyy-MM-dd"))
            .filter(col("BirthDate").isNotNull()))


    # Format BirthDate to DD/MM/YYYY
    df = df.withColumn("birthdate", date_format(col("BirthDate"), "dd/MM/yyyy"))

    # Extract digits from Salary (remove "-") and convert to Integer for further transformations
    df = (df.withColumn("Salary", regexp_extract(col("Salary"), "\\d+", 0))
            .filter(col("Salary") != "")
            .withColumn("Salary", col("Salary").cast("int")))


    # Trim spaces and remove special characters from FirstName and LastName
    column = df.columns
    df = (df.withColumn("FirstName", trim(regexp_replace(column[1], "[^A-Za-z\\s]", "")))
            .withColumn("LastName", trim(regexp_replace(column[2], "[^A-Za-z\\s]", ""))))

    # FullName column by concatenating FirstName and LastName
    df = df.withColumn("fullName", concat(col("FirstName"), lit(" "), col("LastName")))

    # Age with reference to 01-01-2023
    df = df.withColumn("Age", 
                       (datediff(to_date(lit("2023-01-01"), "yyyy-MM-dd"), 
                                 to_date(col("BirthDate"), "dd/MM/yyyy")) / 365).cast("int"))

    #SalaryBucket based on Salary
    df = (df.withColumn("SalaryBucket", 
                        when(col("Salary") < 50000, "A")
                        .when((col("Salary") >= 50000) & (col("Salary") < 100000), "B")
                        .otherwise("C")))
    df = df.drop(column[1], column[2],"LastName")

    return df

In [45]:
transformed_df = transform_data(read)
transformed_df.show()

+----------+----------+-----------+------+-------------------+---+------------+
|EmployeeID| birthdate| Department|Salary|           fullName|Age|SalaryBucket|
+----------+----------+-----------+------+-------------------+---+------------+
|  E001    |12/06/1990|  Finance  | 55000|        Alice White| 32|           B|
|  E002    |03/01/1988| IT        | 90000|          Bob Brown| 35|           B|
|  E003    |15/07/1995| HR        | 47000|         Carol Grey| 27|           A|
|  E004    |23/09/1992| Marketing | 75000|        David Black| 30|           B|
|  E005    |18/03/1985| IT        |125000|          Eve Green| 37|           C|
|  E006    |21/11/1980| Finance   | 85000|       Frank Turner| 42|           B|
|  E007    |15/02/1993| HR        | 62000|       Grace Harper| 29|           B|
|  E008    |09/04/1999| Marketing | 43000|       Holly Norman| 23|           A|
|  E009    |27/05/1996| Finance   |110000|           Ian Ross| 26|           C|
|  E010    |18/08/1997| IT        | 7000

In [38]:
def load_data(df, table_name: str):
    """
    Loads transformed DataFrame into a PostgreSQL database table and creates indexes.
    Note: The url for Postgre is hardcoded. I have created a test Postgre server for this. 

    Parameters:
    df (DataFrame): The DataFrame to load into the database.
    table_name (str): The name of the target table in the database.
    """

    # Convert the table name to lowercase to avoid case issues
    table_name = table_name.lower()

    # PostgreSQL connection details
    db_url = 'jdbc:postgresql://pg-nikhiltest-nnn-nikhiltest.h.aivencloud.com:17081/defaultdb?sslmode=require'
    
    # Extract connection parameters for psycopg2 connection
    url_parts = urllib.parse.urlparse('postgres://avnadmin:AVNS_yGpIBK9-f-QpZ5bA7ri@pg-nikhiltest-nnn-nikhiltest.h.aivencloud.com:17081/defaultdb?sslmode=require')
    dbname = url_parts.path[1:]  # Get database name from db_url (removing the leading '/')
    host = url_parts.hostname
    port = url_parts.port
    user = url_parts.username
    password = url_parts.password

    # Create a connection to PostgreSQL
    try:
        conn = psycopg2.connect(dbname=dbname, user=user, password=password, host=host, port=port)
    except Exception as e:
        print(f"Failed to connect to the database: {e}")
        return

    try:
        # Drop the table if it exists
        drop_table_query = f"DROP TABLE IF EXISTS {table_name};"
        with conn.cursor() as cursor:
            cursor.execute(drop_table_query)
            conn.commit()
            print(f"Dropped table: {table_name}")

        # Create the table with lowercase naming
        create_table_query = f"""
        CREATE TABLE {table_name} (
            employeeid SERIAL PRIMARY KEY,
            birthdate DATE NOT NULL,
            department VARCHAR(100) NOT NULL,
            salary NUMERIC(10, 2) NOT NULL,
            fullname VARCHAR(200) NOT NULL,
            age INT NOT NULL,
            salarybucket VARCHAR(50) NOT NULL
        );
        """
        
        with conn.cursor() as cursor:
            cursor.execute(create_table_query)
            conn.commit()
            print(f"Created table: {table_name}")

        # Check if table was created
        check_table_query = f"""
        SELECT EXISTS (
            SELECT FROM information_schema.tables 
            WHERE table_name = '{table_name}'
        );
        """
        with conn.cursor() as cursor:
            cursor.execute(check_table_query)
            table_exists = cursor.fetchone()[0]
            if table_exists:
                print(f"Table '{table_name}' was successfully created.")
            else:
                print(f"Table '{table_name}' does not exist after creation attempt.")

        # Convert DataFrame column names to lowercase
        df = df.toDF(*[col.lower() for col in df.columns])

        # Print DataFrame schema for debugging
        print("DataFrame schema:")
        df.printSchema()

        # Check if required columns exist
        required_columns = ['birthdate', 'department', 'salary', 'fullname', 'age', 'salarybucket']
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            print(f"Missing columns in DataFrame: {missing_columns}")
            return

        # Write the DataFrame directly to PostgreSQL using Spark's DataFrameWriter
        df.write \
            .format("jdbc") \
            .option("url", db_url) \
            .option("dbtable", table_name) \
            .option("user", user) \
            .option("password", password) \
            .option("driver", "org.postgresql.Driver") \
            .mode("overwrite") \
            .save()

        # Create indexes for enhancing retrieval performance
        index_queries = [
            f"CREATE INDEX IF NOT EXISTS idx_fullname ON {table_name} (fullname);",
            f"CREATE INDEX IF NOT EXISTS idx_age ON {table_name} (age);",
            f"CREATE INDEX IF NOT EXISTS idx_salary ON {table_name} (salary);",
            f"CREATE INDEX IF NOT EXISTS idx_salarybucket ON {table_name} (salarybucket);"
        ]
        
        # Execute each index creation query
        for query in index_queries:
            with conn.cursor() as cursor:
                cursor.execute(query)
                conn.commit()
                print(f"Created index with query: {query}")

        print(f"Data successfully loaded into {table_name} and indexes created.")

    except Exception as e:
        print(f"Error occurred: {e}")

    finally:
        # Ensure the connection is closed
        conn.close()

In [41]:
#Execution code

filePath = "employee_details.csv"
read = read_csv(filePath)
#read.show()


transformed_df = transform_data(read)


load_data(transformed_df, "employees")


                                                                                

Dropped table if it existed: employees
Created table: employees
Table 'employees' was successfully created.
DataFrame schema:
root
 |-- employeeid: string (nullable = true)
 |-- birthdate: string (nullable = true)
 |-- department: string (nullable = true)
 |-- salary: integer (nullable = true)
 |-- fullname: string (nullable = true)
 |-- age: integer (nullable = true)
 |-- salarybucket: string (nullable = false)



                                                                                

Created index with query: CREATE INDEX IF NOT EXISTS idx_fullname ON employees (fullname);
Created index with query: CREATE INDEX IF NOT EXISTS idx_age ON employees (age);
Created index with query: CREATE INDEX IF NOT EXISTS idx_salary ON employees (salary);
Created index with query: CREATE INDEX IF NOT EXISTS idx_salarybucket ON employees (salarybucket);
Data successfully loaded into employees and indexes created.
