# Notebook for EDA and Testing the PySpark Script

This notebook explores the data, focusing on the total_amount column, to detect anomalies that need fixing.

A natural next step is a more extensive analysis to understand why some totals are incorrect and whether there are trends worth exploring to prevent these data problems in the future.

*I had to run the code in a Colab notebook because of JVM errors on my personal computer.*

## Create Spark Session

In [73]:
from pyspark.sql import SparkSession

# Creating SparkSession
spark = SparkSession \
        .builder \
        .appName('TLC_Trip_Records_Service') \
        .getOrCreate()

## Reading Files

In [8]:
download_folder = 'sample_data'
parquet_df = spark.read.parquet(download_folder)

## Filter top 10%

In [9]:
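from pyspark.sql.functions import col, round, format_number

# Count once and reuse the result: every count() triggers a full Spark job
records_count = parquet_df.count()
print(records_count)
top_10_num = int(records_count * 0.1)
print(top_10_num)

# Keep the 10% of trips with the longest trip_distance
filtered_df = parquet_df.orderBy(col('trip_distance').desc()) \
                        .limit(top_10_num)

# Format total_amount as a string with 2 decimals
filtered_df = filtered_df.withColumn('total_amount',
                                     format_number(col('total_amount'), 2))

# Recompute the total from its components, rounded and formatted the same way
filtered_df = filtered_df.withColumn('total_amount_sum',
                                     format_number(
                                         round(col('fare_amount') + col('extra')
                                               + col('tolls_amount') + col('improvement_surcharge')
                                               + col('mta_tax'), 2),
                                         2))
filtered_df.show()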

9071244
907124
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+----------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|total_amount_sum|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+----------------+
|       2| 2022-02-15 18:24:00|  2022-02-15 18:37:00|           NULL|    348798.53|      NULL|              NULL|        
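
The selection above orders by trip_distance and keeps the first `int(count * 0.1)` rows. As a minimal plain-Python sketch of the same rule (no Spark or JVM needed), assuming each trip is a dict with a `trip_distance` key; the helper name `top_n_percent` and the sample distances are hypothetical:

```python
def top_n_percent(trips, fraction):
    """Return the longest trips, keeping int(len(trips) * fraction) of them."""
    if not 0 <= fraction <= 1:
        raise ValueError("fraction must be between 0 and 1")
    keep = int(len(trips) * fraction)  # int() truncates, as in the notebook
    # Longest first; Python's sort is stable, so ties keep their input order
    ranked = sorted(trips, key=lambda trip: trip["trip_distance"], reverse=True)
    return ranked[:keep]


# Ten hypothetical trips, identified by their position in the list
distances = [3.2, 0.0, 12.5, 7.1, 1.4, 9.9, 2.2, 5.0, 4.4, 8.8]
trips = [{"id": i, "trip_distance": d} for i, d in enumerate(distances)]

print([trip["id"] for trip in top_n_percent(trips, 0.1)])  # → [2]
print(int(9071244 * 0.1))                                   # → 907124
```

One difference from the Spark version: Spark's `orderBy` does not promise any particular order among rows with equal trip_distance, so the rows kept at the cut-off can vary between runs; adding a second sort key would make the selection repeatable.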

## Check for null values in total_amount

In [10]:


null_count = filtered_df.where(col('total_amount').isNull()).count()
print(f'Number of null values: {null_count}')


Number of null values: 0


Notes:

There are no null values in total_amount.

## Check when total_amount is 0

In [11]:
zero_df = filtered_df.where(col('total_amount') == 0.0)
zero_count = zero_df.count()
print(f'Number of Records with 0 total_amount: {zero_count}')

Number of Records with 0 total amount: 17


In [12]:
zero_df.show()

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+----------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|total_amount_sum|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+----------------+
|       1| 2022-02-11 16:51:47|  2022-02-11 18:10:38|            1.0|         32.9|       1.0|                 N|         132|         2

In [13]:
# Distinct values of each amount column among the records with a 0 total_amount
amount_columns = ['fare_amount', 'extra', 'mta_tax', 'tip_amount', 'tolls_amount',
                  'improvement_surcharge', 'congestion_surcharge', 'airport_fee']
for column in amount_columns:
    values = [row[column] for row in zero_df.select(column).distinct().collect()]
    print(f'distinct elements for {column}: {values}')

distinct elements for fare_amount: [0.0]
distinct elements for extra: [0.0]
distinct elements for mta_tax: [0.0]
distinct elements for tip_amount: [0.0]
distinct elements for tolls_amount: [0.0]
distinct elements for improvement_surcharge: [0.0]
distinct elements for congestion_surcharge: [0.0, 2.5, None]
distinct elements for airport_fee: [0.0, 1.25, None]


### Filter the rows where congestion_surcharge is null

In [14]:
zero_df.filter(
    (col('congestion_surcharge').isNull())
).show()

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+----------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|total_amount_sum|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+----------------+
|       1| 2022-02-14 07:00:55|  2022-02-14 07:00:58|           NULL|         20.6|      NULL|              NULL|         186|         1

Notes:


*   trip_distance is populated; only 1 of these trips has a distance of 0 (worth checking how many trips have a distance of 0 overall).
*   payment_type is always 0, a code that does not exist in the data dictionary (worth checking).
*   There is no data on whether these were store-and-forward trips (store_and_fwd_flag is null).
*   VendorID: [1]
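
The payment_type observation can be checked against the codes listed in the TLC yellow taxi data dictionary. The sketch below assumes the mapping published for data from this period (1 = Credit card, 2 = Cash, 3 = No charge, 4 = Dispute, 5 = Unknown, 6 = Voided trip); later dictionary versions may add codes, so verify it against the version matching your files. `describe_payment_type` and `count_undocumented` are hypothetical helpers:

```python
from collections import Counter

# Assumed mapping from the TLC yellow taxi data dictionary; verify it against
# the dictionary version that matches your data files.
PAYMENT_TYPES = {
    1: "Credit card",
    2: "Cash",
    3: "No charge",
    4: "Dispute",
    5: "Unknown",
    6: "Voided trip",
}


def describe_payment_type(code):
    """Return the label for a payment_type code, or None if it is undocumented."""
    return PAYMENT_TYPES.get(code)


def count_undocumented(codes):
    """Count how often each code missing from PAYMENT_TYPES appears."""
    return Counter(code for code in codes if code not in PAYMENT_TYPES)


sample = [1, 2, 0, 0, 1, 6, None]
print(describe_payment_type(2))    # → Cash
print(count_undocumented(sample))  # → Counter({0: 2, None: 1})
```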



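### Filter the rows where congestion_surcharge is 2.5 or 1.25


In [15]:
# congestion_surcharge never takes the value 1.25 in zero_df (see the distinct values above);
# 1.25 appears in airport_fee, so use col('airport_fee') == 1.25 to inspect those rows
zero_df.filter(
    (col('congestion_surcharge') == 2.5) | (col('congestion_surcharge') == 1.25)
).distinct().show()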

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+----------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|total_amount_sum|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+----------------+
|       2| 2022-03-28 19:27:18|  2022-03-28 19:43:56|            1.0|         8.24|       1.0|                 N|         138|         2

Notes:


*   only 1 record
*   many trips with distance 0
*   payment type: [1]
*   not a store and forward trip
*   VendorID: [2]


In [16]:
zero_df.filter(
    (col('congestion_surcharge') == 2.5)
).distinct().count()

1

In [17]:
from pyspark.sql.functions import when
zero_df = zero_df.withColumn('is_sum_correct',
                                      when(col('total_amount') == col('total_amount_sum'), True).otherwise(False)
                                      )
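
Here is_sum_correct compares two strings produced by format_number, so it only works when both sides are formatted identically. A numeric alternative is to treat amounts as equal when they agree within half a cent. A plain-Python sketch, assuming amounts arrive as floats; the helper name `amounts_match` and the 0.005 tolerance are illustrative choices:

```python
import math


def amounts_match(reported, recomputed, tolerance=0.005):
    """True when two dollar amounts agree to within `tolerance` dollars."""
    if reported is None or recomputed is None:
        return False  # treat missing values as a mismatch
    return math.isclose(reported, recomputed, rel_tol=0.0, abs_tol=tolerance)


fare_parts = [0.1, 0.2]                      # classic binary floating-point case
print(sum(fare_parts) == 0.3)                # → False
print(amounts_match(sum(fare_parts), 0.3))   # → True
print(amounts_match(99.96, 83.30))           # → False
```

In PySpark the same check can be expressed on numeric columns as `pyspark.sql.functions.abs(col('total_amount') - col('total_amount_sum')) < 0.005`, provided neither column has been turned into a formatted string.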

## Check when total_amount is below 0

In [18]:
below_zero_df = filtered_df.where(col('total_amount') < 0)
below_zero_count = below_zero_df.count()
print(f'Number of records below 0: {below_zero_count}')


Number of records below 0: 3185


In [19]:
below_zero_df.show(50)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+----------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|total_amount_sum|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+----------------+
|       2| 2022-03-30 16:39:28|  2022-03-30 19:07:15|            1.0|       181.21|       5.0|                 N|         265|         1

## Distinct values of total_amount

In [21]:
print('distinct elements for total_amount: '+ str(filtered_df.select(col('total_amount')).distinct().rdd.flatMap(lambda x: x).collect()))

distinct elements for airport_fee: ['27.90', '24.44', '16.09', '22.20', '25.12', '18.85', '14.67', '20.19', '15.56', '47.47', '29.55', '46.72', '24.87', '14.35', '25.47', '39.78', '39.87', '17.35', '37.71', '45.38', '17.81', '41.15', '69.90', '18.00', '15.58', '77.30', '22.74', '15.45', '21.75', '23.07', '41.97', '29.05', '26.68', '19.00', '25.27', '41.37', '17.42', '21.40', '38.07', '36.10', '35.40', '44.97', '36.37', '19.61', '24.27', '17.89', '26.81', '73.81', '18.50', '76.59', '14.50', '25.63', '48.06', '19.08', '33.00', '8.50', '19.33', '39.66', '14.83', '21.26', '49.21', '20.93', '10.21', '18.51', '25.66', '13.33', '34.17', '20.67', '47.24', '19.43', '11.61', '29.26', '103.72', '29.93', '19.40', '15.00', '19.02', '61.77', '45.14', '19.13', '42.23', '13.28', '26.33', '29.28', '14.24', '16.21', '43.19', '24.35', '13.84', '16.97', '18.01', '31.06', '27.47', '21.25', '47.11', '17.10', '25.71', '33.88', '38.72', '27.86', '16.10', '21.47', '22.81', '20.00', '18.33', '37.63', '32.97', '

## Comparing total_amount with total_amount_sum

In [23]:
filtered_df.select(col('total_amount'),col('total_amount_sum')).orderBy(col('total_amount').desc()).show()

+------------+----------------+
|total_amount|total_amount_sum|
+------------+----------------+
|       99.97|           99.97|
|       99.96|           83.30|
|       99.96|           80.80|
|       99.96|           80.80|
|       99.96|           83.30|
|       99.96|           81.97|
|       99.96|           83.30|
|       99.95|           99.95|
|       99.95|           83.30|
|       99.95|           89.95|
|       99.95|           91.95|
|       99.94|           91.31|
|       99.92|           90.30|
|       99.90|           82.00|
|       99.87|           83.25|
|       99.85|           96.10|
|       99.85|           98.60|
|       99.85|           98.60|
|       99.84|           99.84|
|       99.81|           78.85|
+------------+----------------+
only showing top 20 rows
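
The ordering above starts at 99.97 because format_number turns total_amount into a string, and strings sort character by character: '99.97' ranks above '103.72', a value that appears in the distinct list above. format_number also inserts thousands separators (1234.5 becomes '1,234.50'), which can no longer be parsed as a number. A plain-Python illustration, using Python's `,.2f` format as a stand-in for format_number's output:

```python
# Python's ",.2f" format (grouping commas, two decimals) is used here as a
# stand-in for the strings Spark's format_number produces.
totals = [99.97, 103.72, 1234.5, 8.5]
formatted = [f"{t:,.2f}" for t in totals]
print(formatted)                        # → ['99.97', '103.72', '1,234.50', '8.50']

# Descending order on strings is lexicographic, not numeric
print(sorted(formatted, reverse=True))  # → ['99.97', '8.50', '103.72', '1,234.50']
print(sorted(totals, reverse=True))     # → [1234.5, 103.72, 99.97, 8.5]

# A formatted value with a grouping comma no longer parses as a number
try:
    float("1,234.50")
except ValueError as err:
    print("cannot parse:", err)
```

Keeping the column numeric with `round(col('total_amount'), 2)` and applying format_number only for display would avoid both the ordering and the parsing problem.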



## Check whether the sum is correct

In [24]:
total_sum_df = filtered_df.withColumn('is_sum_correct',
                                      when(col('total_amount') == col('total_amount_sum'), True).otherwise(False)
                                      )
total_sum_df.select(col('total_amount'),col('total_amount_sum'),col('is_sum_correct')).orderBy(col('total_amount').desc()).show()


+------------+----------------+--------------+
|total_amount|total_amount_sum|is_sum_correct|
+------------+----------------+--------------+
|       99.97|           99.97|          true|
|       99.96|           83.30|         false|
|       99.96|           80.80|         false|
|       99.96|           80.80|         false|
|       99.96|           83.30|         false|
|       99.96|           81.97|         false|
|       99.96|           83.30|         false|
|       99.95|           99.95|          true|
|       99.95|           83.30|         false|
|       99.95|           89.95|         false|
|       99.95|           91.95|         false|
|       99.94|           91.31|         false|
|       99.92|           90.30|         false|
|       99.90|           82.00|         false|
|       99.87|           83.25|         false|
|       99.85|           96.10|         false|
|       99.85|           98.60|         false|
|       99.85|           98.60|         false|
|       99.84
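
Several mismatches above differ by amounts that look like the columns left out of total_amount_sum: 99.85 vs 98.60 is 1.25 (the airport_fee value seen earlier) and 99.85 vs 96.10 is 3.75 (2.50 congestion_surcharge plus 1.25 airport_fee). The TLC data dictionary describes tip_amount as credit-card tips only, with cash tips not recorded, so tip_amount, congestion_surcharge and airport_fee may also belong in the sum. A plain-Python sketch under that assumption, treating missing values as 0; the trip record is hypothetical:

```python
from decimal import Decimal

# Columns assumed to add up to total_amount, based on the TLC yellow taxi data
# dictionary; the notebook's total_amount_sum uses only NOTEBOOK_COLUMNS.
NOTEBOOK_COLUMNS = ["fare_amount", "extra", "mta_tax", "tolls_amount",
                    "improvement_surcharge"]
EXTRA_COLUMNS = ["tip_amount", "congestion_surcharge", "airport_fee"]


def recompute_total(record, columns):
    """Sum the given amount columns exactly with Decimal, treating None as 0."""
    total = sum((Decimal(str(record.get(c) or 0)) for c in columns), Decimal("0"))
    return total.quantize(Decimal("0.01"))


# Hypothetical trip record, for illustration only
trip = {"fare_amount": 52.0, "extra": 2.5, "mta_tax": 0.5, "tolls_amount": 6.55,
        "improvement_surcharge": 0.3, "tip_amount": 12.35,
        "congestion_surcharge": 2.5, "airport_fee": 1.25, "total_amount": 77.95}

print(recompute_total(trip, NOTEBOOK_COLUMNS))                  # → 61.85
print(recompute_total(trip, NOTEBOOK_COLUMNS + EXTRA_COLUMNS))  # → 77.95
```

Whether each column belongs in the total should be confirmed against the dictionary version for the data period before changing the pipeline.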

## Checking the correct ones

In [25]:
total_sum_df_correct = total_sum_df.filter(
                                    col('total_amount') == col('total_amount_sum')
                                    )

print(total_sum_df_correct.count())
total_sum_df_correct.show()


9833
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+----------------+--------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|total_amount_sum|is_sum_correct|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+----------------+--------------+
|       2| 2022-01-26 15:26:00|  2022-01-26 15:27:00|           NULL|    123474.27|   

## Checking the incorrect ones

In [27]:
total_sum_df.cache()

DataFrame[VendorID: bigint, tpep_pickup_datetime: timestamp_ntz, tpep_dropoff_datetime: timestamp_ntz, passenger_count: double, trip_distance: double, RatecodeID: double, store_and_fwd_flag: string, PULocationID: bigint, DOLocationID: bigint, payment_type: bigint, fare_amount: double, extra: double, mta_tax: double, tip_amount: double, tolls_amount: double, improvement_surcharge: double, total_amount: string, congestion_surcharge: double, airport_fee: double, total_amount_sum: string, is_sum_correct: boolean]

In [28]:
total_sum_df_incorrect = total_sum_df.filter(
                                    col('total_amount') != col('total_amount_sum')
                                    )

print(total_sum_df_incorrect.count())
total_sum_df_incorrect.show()

80879
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+----------------+--------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|total_amount_sum|is_sum_correct|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+----------------+--------------+
|       2| 2022-02-15 18:24:00|  2022-02-15 18:37:00|           NULL|    348798.53|  

In [31]:
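# Incorrect sums where total_amount is negative (despite "zero" in the name, this keeps totals below 0)
total_sum_df_incorrect_zero_amount = total_sum_df_incorrect.filter(
                                    col('total_amount') < 0
                                    )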

print(total_sum_df_incorrect_zero_amount.count())
total_sum_df_incorrect_zero_amount.show()

341
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+----------------+--------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|total_amount_sum|is_sum_correct|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+----------------+--------------+
|       2| 2022-03-30 16:39:28|  2022-03-30 19:07:15|            1.0|       181.21|    

In [44]:
test_df = total_sum_df_incorrect_zero_amount.withColumn('total_amount2',
                                     when(col('total_amount') != col('total_amount_sum'),
                                          format_number(col('total_amount_sum')*-1, 2)
                                          )\
                                          .otherwise(
                                              format_number(col('total_amount')*-1, 2))

                                     )

print(test_df.count())
test_df.show(200)

341
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+----------------+--------------+-------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|total_amount_sum|is_sum_correct|total_amount2|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+----------------+--------------+-------------+
|       2| 2022-03-30 16:39:28|  2022-03-30 1
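
total_amount2 multiplies the value by -1. That works in this cell because every row has total_amount < 0, but on rows that are not all negative it would also turn positive totals negative. Taking the absolute value only changes the sign of negative amounts. A plain-Python sketch of the difference; `flip_sign` and `to_positive` are hypothetical names:

```python
def flip_sign(amount):
    """What `col(...) * -1` does: negate every value."""
    return amount * -1


def to_positive(amount):
    """Make negative amounts positive and leave positive ones unchanged."""
    return abs(amount)


amounts = [-52.3, 18.75, 0.0]
print([flip_sign(a) for a in amounts])    # → [52.3, -18.75, -0.0]
print([to_positive(a) for a in amounts])  # → [52.3, 18.75, 0.0]
```

In PySpark the equivalent is `pyspark.sql.functions.abs(col('total_amount'))`, applied to a numeric column.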

In [45]:

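# This fails with a TypeError: na.fill is a DataFrame method, but here it is
# chained onto a Column (the result of cast); the next cell applies it to the DataFrame
total_amount_not_null_df = test_df.withColumn('total_amount',
                                     col('total_amount').cast('float').na.fill(0.0, ['total_amount'])
                                     )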

TypeError: ignored

In [47]:
test_df = total_sum_df_incorrect_zero_amount.withColumn("total_amount", col("total_amount").cast("float")).na.fill(0.0, ["total_amount"])
test_df.show()

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+----------------+--------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|total_amount_sum|is_sum_correct|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+----------------+--------------+
|       2| 2022-03-30 16:39:28|  2022-03-30 19:07:15|            1.0|       181.21|       5

# Main PySpark Script

In [80]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, round, abs as spark_abs

# Creating SparkSession
spark = SparkSession \
        .builder \
        .appName('TLC_Trip_Records_Service') \
        .getOrCreate()

# Reading Parquet files to a dataframe
input_folder = 'sample_data'
df = spark.read.parquet(input_folder)

# Count the number of rows in the DataFrame (reused below instead of counting again)
records_count = df.count()
print(f'Total number of TLC records: {records_count}')

# Filtering top 10% trips
## This variable can be modified
top_n = 0.1
top_10_num = int(records_count * top_n)
print(f'Now filtered top {top_n*100}% of TLC records. Number of records :{top_10_num}')
# Order elements and get top N %
top_n_df = df.orderBy(col('trip_distance').desc()) \
             .limit(top_10_num)

# Replace None values with 0.0 (cast to double, not float, to avoid precision loss)
total_amount_not_null_df = top_n_df.withColumn('total_amount', col('total_amount').cast('double'))
total_amount_not_null_df = total_amount_not_null_df.na.fill(0.0, ['total_amount'])
print('None values replaced.')

# Round total_amount to 2 decimals; round keeps the column numeric, whereas
# format_number returns a string with thousands separators
rounded_df = total_amount_not_null_df.withColumn('total_amount', round(col('total_amount'), 2))
print('Formatted total_amount with 2 decimals.')

# Calculating total_amount based on values in relevant columns for the total sum.
# Does not include cash tips, as specified in the data dictionary
rounded_df = rounded_df.withColumn('total_amount_sum',
                                   round(col('fare_amount') + col('extra')
                                         + col('tolls_amount') + col('improvement_surcharge')
                                         + col('mta_tax'), 2))
print('total_amount correction column calculated.')

# Add column to check if calculation is correct
total_sum_df = rounded_df.withColumn('is_sum_correct',
                                     when(col('total_amount') == col('total_amount_sum'), True).otherwise(False))
print('total_amount validation column created.')

# Use the recomputed sum where it disagrees with total_amount, and take the absolute
# value so negative amounts become positive (multiplying by -1 would also turn
# positive amounts negative)
formated_total_amount_df = total_sum_df.withColumn('total_amount',
                                     when(col('total_amount') != col('total_amount_sum'),
                                          spark_abs(col('total_amount_sum')))
                                     .otherwise(spark_abs(col('total_amount'))))
print('total_amount values changed to positive.')

# Updating total_amount column: it already holds the corrected, positive value,
# so only the helper columns are dropped
updated_df = formated_total_amount_df.drop('total_amount_sum', 'is_sum_correct')
print('total_amount new values updated accordingly.')

# Output location
## This variable can be modified and moved to a config file
## Spark creates the output directory when writing, so os.mkdir is not needed
output_folder = '/output/top_10_yellow_line_TLC_records.parquet'

# Save the DataFrame as a Parquet file
updated_df.write\
          .mode("overwrite")\
          .parquet(output_folder)
print('Parquet file saved.')

# Display output dataframe
updated_df.show()

# Stop spark
spark.stop()
print('SparkSession over.')


Total number of TLC records: 9071244
Now filtered top 10.0% of TLC records. Number of records :907124
None values replaced.
Formatted total_amount with 2 decimals.
total_amount correction column calculated.
total_amount validation column created.
total_amount values changed to positive.
total_amount new values updated accordingly.
Parquet file saved.
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+-----------

In [None]:
from pyspark.sql.functions import col, when, round, abs as spark_abs


# Functions


# Filter top N records (top_n is a fraction, e.g. 0.1 for the top 10%)
def filter_top_n(df, top_n):
    top_n_count = int(df.count() * top_n)
    return df.orderBy(col('trip_distance').desc()).limit(top_n_count)

# Replace None values in total_amount with 0.0 (cast to double to keep precision)
def replace_none_values(df):
    not_null_df = df.withColumn('total_amount', col('total_amount').cast('double'))
    return not_null_df.na.fill(0.0, ['total_amount'])

# Round total_amount to 2 decimals, keeping the column numeric
def format_total_amount(df):
    return df.withColumn('total_amount', round(col('total_amount'), 2))

# Recompute total_amount from its components (cash tips are not part of total_amount)
def calculate_total_amount(df):
    return df.withColumn('total_amount_sum', round(
        col('fare_amount') + col('extra') + col('tolls_amount')
        + col('improvement_surcharge') + col('mta_tax'), 2))

# Flag rows where total_amount matches the recomputed sum
def add_total_amount_validation_column(df):
    return df.withColumn('is_sum_correct',
                         when(col('total_amount') == col('total_amount_sum'), True).otherwise(False))

# Use the recomputed sum where it disagrees, and make negative amounts positive
def convert_total_amount_to_positive(df):
    return df.withColumn('total_amount',
                         when(col('total_amount') != col('total_amount_sum'),
                              spark_abs(col('total_amount_sum')))
                         .otherwise(spark_abs(col('total_amount'))))

# Drop the helper columns; total_amount already holds the corrected value
def update_total_amount_column(df):
    return df.drop('total_amount_sum', 'is_sum_correct')

In [81]:

from pyspark.sql import SparkSession
# from utils.pyspark_utils import *
# Creating SparkSession
spark = SparkSession \
        .builder \
        .appName('TLC_Trip_Records_Service') \
        .getOrCreate()

# Reading Parquet files to a dataframe
input_folder = 'sample_data'
df = spark.read.parquet(input_folder)
# Count the number of rows in the DataFrame
records_count = df.count()
print(f'Total number of TLC records: {records_count}')

# Filtering top 10% trips
## This variable can be modified
top_n = 0.1
top_10_num = int(records_count * top_n)
print(f'Now filtered top {top_n*100}% of TLC records. Number of records :{top_10_num}')

# Order elements and get top N %
top_n_df = filter_top_n(df, top_n)

# Replace None values with 0.0
total_amount_not_null_df = replace_none_values(top_n_df)
print('None values replaced.')

# Format total_amount numbers to have 2 decimals
total_amount_not_null_df = format_total_amount(total_amount_not_null_df)
print('Formatted total_amount with 2 decimals.')

# Calculating total_amount based on values in relevant columns for the total sum.
# Does not include cash tips, as specified in the data dictionary
total_amount_not_null_df = calculate_total_amount(total_amount_not_null_df)
print('total_amount correction column calculated.')

# Add column to check if calculation is correct
total_amount_not_null_df = add_total_amount_validation_column(total_amount_not_null_df)
print('total_amount validation column created.')

# Converting total_amount column to the positive sum of the variables for the formula to calculate total_amount
total_amount_not_null_df = convert_total_amount_to_positive(total_amount_not_null_df)
print('total_amount values changed to positive.')

# Updating total_amount column
updated_df = update_total_amount_column(total_amount_not_null_df)
print('total_amount new values updated accordingly.')

# Output location
## This variable can be modified and moved to a config file
## Spark creates the output directory when writing, so os.mkdir is not needed
output_folder = '/output/top_10_yellow_line_TLC_records.parquet'

# Save the DataFrame as a Parquet file
updated_df.write\
          .mode("overwrite")\
          .parquet(output_folder)
print('Parquet file saved.')


# Display output dataframe
updated_df.show()

# Stop spark
spark.stop()
print('SparkSession over.')

Total number of TLC records: 9071244
Now filtered top 10.0% of TLC records. Number of records :907124
None values replaced.
Formatted total_amount with 2 decimals.
total_amount correction column calculated.
total_amount validation column created.
total_amount values changed to positive.
total_amount new values updated accordingly.
Parquet file saved.
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+-----------
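
The script comments suggest moving top_n and the output folder to a config file. One way to do that is a small JSON file read at startup, with defaults for missing keys. The file name `pipeline_config.json`, the key names and the validation rule are assumptions for illustration:

```python
import json
from pathlib import Path

# Defaults mirror the values hard-coded in the script above
DEFAULTS = {
    "input_folder": "sample_data",
    "output_folder": "/output/top_10_yellow_line_TLC_records.parquet",
    "top_n": 0.1,
}


def load_config(path="pipeline_config.json"):
    """Read settings from a JSON file, falling back to DEFAULTS for missing keys."""
    config = dict(DEFAULTS)  # copy so DEFAULTS itself is never modified
    config_path = Path(path)
    if config_path.exists():
        config.update(json.loads(config_path.read_text()))
    if not 0 < float(config["top_n"]) <= 1:
        raise ValueError(f"top_n must be in (0, 1], got {config['top_n']}")
    return config


config = load_config("does_not_exist.json")
print(config["top_n"], config["output_folder"])
```

In the script, `top_n = config['top_n']`, `input_folder = config['input_folder']` and `output_folder = config['output_folder']` would then replace the hard-coded assignments.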