# Data Validation and Cleansing

In [1]:
# Libraries for creating the SparkSession and DataFrame
from pyspark.sql import SparkSession
from pyspark.sql.functions import col  # for column manipulation

In [2]:
# Create the SparkSession and configure its S3A connection to MinIO.
#
# Connection settings are read from environment variables so that real
# credentials never have to be committed with the notebook:
#   MINIO_ENDPOINT    (fallback: "http://minio:9000")
#   MINIO_ACCESS_KEY  (fallback: "admin")
#   MINIO_SECRET_KEY  (fallback: "password")
# The fallbacks are the local-development values this notebook used before,
# so behaviour is unchanged when the variables are not set.
#
# NOTE: getOrCreate() returns the already-running session if there is one; in
# that case Spark warns that only runtime SQL configurations take effect, so
# the spark.hadoop.* settings below may not be applied to an existing session.
import os

spark = (
    SparkSession.builder
    .appName("EDA and Cleansing")
    .config("spark.hadoop.fs.s3a.endpoint", os.environ.get("MINIO_ENDPOINT", "http://minio:9000"))
    .config("spark.hadoop.fs.s3a.access.key", os.environ.get("MINIO_ACCESS_KEY", "admin"))
    .config("spark.hadoop.fs.s3a.secret.key", os.environ.get("MINIO_SECRET_KEY", "password"))
    .config("spark.hadoop.fs.s3a.path.style.access", "true")
    .config("spark.hadoop.fs.s3a.impl", "org.apache.hadoop.fs.s3a.S3AFileSystem")
    .getOrCreate()
)


24/11/24 13:34:05 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [3]:
# Location of the raw Parquet file inside MinIO
input_path = "s3a://warehouse/raw/yellow_tripdata_2021-06.parquet"

# Load the Parquet file from MinIO into a Spark DataFrame
parquet_reader = spark.read
df = parquet_reader.parquet(input_path)


24/11/24 13:34:05 WARN MetricsConfig: Cannot locate configuration: tried hadoop-metrics2-s3a-file-system.properties,hadoop-metrics2.properties
                                                                                

In [4]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, unix_timestamp, when, count, mean, stddev, expr
from functools import reduce

def _three_sigma_stats(frame, column, n_sigma=3):
    """Return ``(mean, stddev, lower, upper)`` for ``column`` of ``frame``.

    Both aggregates are computed in a single ``select`` so that one Spark
    action is triggered per column instead of two.

    Args:
        frame: Spark DataFrame to aggregate.
        column: Name of the numeric column to summarise.
        n_sigma: Half-width of the accepted band, in standard deviations.

    Returns:
        Tuple ``(mean, stddev, mean - n_sigma * stddev, mean + n_sigma * stddev)``.

    Raises:
        ValueError: If either aggregate comes back null (for example when no
            rows are left in ``frame``), since no bounds can be derived.
    """
    stats_row = frame.select(
        mean(column).alias("mean_value"),
        stddev(column).alias("std_value"),
    ).first()
    mean_value, std_value = stats_row["mean_value"], stats_row["std_value"]
    if mean_value is None or std_value is None:
        raise ValueError(
            f"Cannot compute outlier bounds for '{column}': "
            "mean/stddev is null (too few rows left after filtering?)"
        )
    return (
        mean_value,
        std_value,
        mean_value - n_sigma * std_value,
        mean_value + n_sigma * std_value,
    )


# Step 1: Remove duplicate rows
df = df.dropDuplicates()

# Step 2: Remove rows with missing values in critical columns
# ('RatecodeID' is spelled exactly as in the dataset schema so the lookup does
# not depend on case-insensitive column resolution.)
df = df.na.drop(subset=["RatecodeID", "store_and_fwd_flag", "congestion_surcharge", "passenger_count", "airport_fee"])

# Step 3: Filter 'VendorID' where values are 1 or 2
df = df.filter(col("VendorID").isin([1, 2]))

# Step 4: Filter 'RatecodeID' with valid values (1-6)
df = df.filter(col("RatecodeID").isin([1, 2, 3, 4, 5, 6]))

# Step 5: Filter 'payment_type' with valid values (1-6)
df = df.filter(col("payment_type").isin([1, 2, 3, 4, 5, 6]))

# Step 6: Filter 'store_and_fwd_flag' with valid values ('Y', 'N')
df = df.filter(col("store_and_fwd_flag").isin(['Y', 'N']))

# Step 7: Filter 'mta_tax' with value 0.5
df = df.filter(col("mta_tax") == 0.5)

# Step 8: Filter 'improvement_surcharge' with value 0.3
df = df.filter(col("improvement_surcharge") == 0.3)

# Step 9: Filter 'airport_fee' with values 0 or 1.25
df = df.filter((col("airport_fee") == 0) | (col("airport_fee") == 1.25))

# Step 10: Filter 'fare_amount' > 0 and remove outliers (outside mean +/- 3 stddev).
# The statistics are taken before the '> 0' filter is applied, i.e. over all
# rows that survived the previous steps.
fare_mean, fare_std, lower_bound_fare, upper_bound_fare = _three_sigma_stats(df, "fare_amount")
df = df.filter(
    (col("fare_amount") > 0)
    & (col("fare_amount") >= lower_bound_fare)
    & (col("fare_amount") <= upper_bound_fare)
)

# Step 11: Replace 'passenger_count' = 0 with median and filter valid values.
# The median is taken over rows with a positive count only; relativeError=0
# requests the exact quantile.
median_candidates = df.filter(col("passenger_count") > 0).approxQuantile("passenger_count", [0.5], 0)
# approxQuantile yields an empty list when there is nothing to summarise;
# indexing it blindly would raise IndexError.
median_passenger_count = median_candidates[0] if median_candidates else None
if median_passenger_count is not None:
    df = df.withColumn(
        "passenger_count",
        when(col("passenger_count") == 0, median_passenger_count).otherwise(col("passenger_count")),
    )
df = df.filter(col("passenger_count").isin([1, 2, 3, 4, 5, 6]))

# Step 12: Filter 'trip_distance' > 0 and <= 150 miles
df = df.filter((col("trip_distance") > 0) & (col("trip_distance") <= 150))

# Step 13: Create and filter 'trip_duration' > 0 (in minutes)
# unix_timestamp() yields seconds, hence the division by 60.
df = df.withColumn(
    "trip_duration",
    (unix_timestamp(col("tpep_dropoff_datetime")) - unix_timestamp(col("tpep_pickup_datetime"))) / 60
)
df = df.filter(col("trip_duration") > 0)

# Step 14: Filter 'speed' values > 0 and <= 55 mph
# trip_duration is in minutes, so dividing it by 60 gives hours.
df = df.withColumn("speed", col("trip_distance") / (col("trip_duration") / 60))
df = df.filter((col("speed") > 0) & (col("speed") <= 55))

# Step 15: Filter 'tip_amount' >= 0 and remove outliers (outside mean +/- 3 stddev)
tip_mean, tip_std, lower_bound_tip, upper_bound_tip = _three_sigma_stats(df, "tip_amount")
df = df.filter(
    (col("tip_amount") >= 0)
    & (col("tip_amount") >= lower_bound_tip)
    & (col("tip_amount") <= upper_bound_tip)
)

# Step 16: Filter 'tolls_amount' >= 0
df = df.filter(col("tolls_amount") >= 0)

# Step 17: Filter 'total_amount' >= 0 and remove outliers (outside mean +/- 3 stddev)
total_mean, total_std, lower_bound_total, upper_bound_total = _three_sigma_stats(df, "total_amount")
df = df.filter(
    (col("total_amount") >= 0)
    & (col("total_amount") >= lower_bound_total)
    & (col("total_amount") <= upper_bound_total)
)

# Step 18: Filter 'congestion_surcharge' >= 0
df = df.filter(col("congestion_surcharge") >= 0)

24/11/24 13:34:16 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/11/24 13:34:16 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/11/24 13:34:24 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/11/24 13:34:24 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/11/24 13:34:24 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/11/24 13:34:24 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/11/24 13:34:24 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/11/24 13:34:24 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/11/24 13:34:24 WARN RowBasedKeyValueBatch: Calling spill() on

In [5]:
# Sanity check on the cleansed DataFrame: print its schema, then preview the first 10 rows
df.printSchema()
df.show(n=10)

root
 |-- VendorID: long (nullable = true)
 |-- tpep_pickup_datetime: timestamp_ntz (nullable = true)
 |-- tpep_dropoff_datetime: timestamp_ntz (nullable = true)
 |-- passenger_count: double (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: double (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: long (nullable = true)
 |-- DOLocationID: long (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- airport_fee: double (nullable = true)
 |-- trip_duration: double (nullable = true)
 |-- speed: double (nullable = true)



24/11/24 13:35:47 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/11/24 13:35:47 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
                                                                                

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+-------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|     trip_duration|              speed|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+-------------------+
|       2| 2021-06-01 00:05:01|  2021-06-01 00:12:49|            1.0| 

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# Bring the Spark DataFrame into pandas for the null analysis.
# NOTE: toPandas() collects every row of `df` onto the driver.
df_pandas = df.toPandas()

# Number of null values per column
null_counts = df_pandas.isnull().sum()

# Tabulate the counts (largest first) so they can be plotted
null_counts_df = pd.DataFrame(
    {"Field": null_counts.index, "Null Count": null_counts.values}
).sort_values(by="Null Count", ascending=False)

# Horizontal bar chart: one bar per field, bar length = null count
fig, ax = plt.subplots(figsize=(12, 6))
ax.barh(null_counts_df["Field"], null_counts_df["Null Count"], color="orange", alpha=0.8)
ax.set_title("Total Nulls by Field", fontsize=16)
ax.set_xlabel("Null Count", fontsize=12)
ax.set_ylabel("Fields", fontsize=12)

# Write each bar's value at its tip; bars sit at positions 0..n-1 in table order
for bar_position, null_total in enumerate(null_counts_df["Null Count"]):
    ax.text(null_total, bar_position, str(null_total), va="center", ha="left", fontsize=10, color="black")

fig.tight_layout()
plt.show()


24/11/24 13:37:53 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/11/24 13:37:53 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/11/24 13:37:59 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/11/24 13:37:59 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/11/24 13:37:59 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/11/24 13:37:59 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/11/24 13:37:59 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/11/24 13:37:59 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
24/11/24 13:37:59 WARN RowBasedKeyValueBatch: Calling spill() on