# Visualizing Big Data

### Read-in and preprocess data

In [None]:
import sys
from datetime import datetime
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType
from pyspark.sql.functions import col, sum as _sum

# Path of the taxi dataset (.csv.bz2) that is loaded with textFile() below.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
DATA_FILE = r'/Users/michaelkwok/Documents/BU MET/6 Fall2025/CS777/hw3/taxi-data-sorted-small.csv.bz2'
# Base folder under which save_string_to_txt() writes its output sub-folders.
# NOTE(review): the value already ends with '/', and save_string_to_txt() inserts
# another '/' when it appends the sub-folder name.
OUTPUT_DIR = r'/Users/michaelkwok/Documents/BU MET/6 Fall2025/CS777/hw5/output/' # make sure this folder exists

def save_string_to_txt(text, output_folder):
    """Write ``text`` as a single-row text output under ``OUTPUT_DIR``.

    The string is wrapped in a one-row, one-column ("content") DataFrame using
    the module-level ``spark`` session, coalesced to one partition and written
    in "overwrite" mode to ``OUTPUT_DIR/output_folder``.

    Args:
        text: The string to store.
        output_folder: Name of the sub-folder (inside ``OUTPUT_DIR``) to write to.
    """
    target_path = f'{OUTPUT_DIR}/{output_folder}'
    single_row = spark.createDataFrame([(text,)], ["content"])
    writer = single_row.coalesce(1).write.mode("overwrite")
    writer.text(target_path)

# ============================================================================
# Read Data into Spark DataFrame
# ============================================================================

# start (or reuse) the Spark session for this notebook
spark = (
    SparkSession.builder
    .appName("TaxiSparkCloud")
    .getOrCreate()
)

# load the raw file as an RDD of text lines ...
lines = spark.sparkContext.textFile(DATA_FILE)

# ... and break every line into its comma-separated fields
taxiLines = lines.map(lambda raw_line: raw_line.split(','))

# schema of the dataset (source: https://chriswhong.com/open-data/foil_nyc_taxi/)
# every field is first read as a nullable string; the names are listed in file order
schema = StructType([
    StructField(field_name, StringType(), True)
    for field_name in (
        "medallion",
        "hack_license",
        "pickup_datetime",
        "dropoff_datetime",
        "trip_time_in_secs",
        "trip_distance",
        "pickup_longitude",
        "pickup_latidue",  # Note: keeping the typo "latidue"
        "dropoff_longitude",
        "dropoff_latitude",
        "payment_type",
        "fare_amount",
        "surcharge",
        "mta_tax",
        "tip_amount",
        "tolls_amount",
        "total_amount",
    )
])

# build the spark dataframe from the split lines
df = spark.createDataFrame(taxiLines, schema)

# numeric columns and the type each one is cast to; all of them become doubles,
# the remaining columns stay as strings
cast_data_types = dict.fromkeys(
    [
        "trip_time_in_secs",
        "trip_distance",
        "pickup_longitude",
        "pickup_latidue",
        "dropoff_longitude",
        "dropoff_latitude",
        "fare_amount",
        "surcharge",
        "mta_tax",
        "tip_amount",
        "tolls_amount",
        "total_amount",
    ],
    "double",
)

# replace each listed column by its cast version, one column at a time
for column, type_ in cast_data_types.items():
    casted_column = col(column).cast(type_)
    df = df.withColumn(column, casted_column)


# ============================================================================
# Data Preprocessing
# ============================================================================

# exception handling and removing wrong data lines
def isFloat(value):
    """Return True if ``value`` can be converted with ``float()``, else False.

    Only the errors ``float()`` raises for unusable input are treated as
    "not a float"; anything else (e.g. KeyboardInterrupt) propagates instead of
    being silently swallowed.

    Args:
        value: Any object, typically one string field of a split CSV line.

    Returns:
        bool: True when ``float(value)`` succeeds, False otherwise.
    """
    try:
        float(value)
    except (TypeError, ValueError, OverflowError):
        # TypeError: unsupported type such as None
        # ValueError: non-numeric text such as "abc" or ""
        # OverflowError: an int too large to be represented as a float
        return False
    return True
    
# keep only lines that have exactly 17 values and a numeric, non-zero trip distance and fare
def correctRows(p):
    """Return the row ``p`` if it is usable, otherwise None.

    A row is kept when it has exactly 17 fields and fields 5 and 11
    (trip_distance and fare_amount in the schema order) both parse as floats
    and are both different from zero. Meant to be used as a filter predicate:
    the returned row is truthy, None is falsy.

    Args:
        p: One line of the dataset, already split into its fields.

    Returns:
        ``p`` itself for a valid row, None for a row that should be dropped.
    """
    if len(p) != 17:
        return None

    distance, fare = p[5], p[11]
    if not (isFloat(distance) and isFloat(fare)):
        return None

    if float(distance) == 0 or float(fare) == 0:
        return None

    return p
            
# cleaning up data: drop lines with a wrong number of fields or with a
# non-numeric / zero trip distance or fare
taxiLinesCorrected = taxiLines.filter(correctRows)

# Rebuild the dataframe from the cleaned rows. The filtered RDD used to be
# computed but never used, so malformed lines were still passed to
# createDataFrame together with the 17-field schema.
df = spark.createDataFrame(taxiLinesCorrected, schema)
for column, type_ in cast_data_types.items():
    df = df.withColumn(column, col(column).cast(type_))

# a valid taxi trip duration is between 2 minutes and 1 hour (both inclusive)
df = df.filter(col('trip_time_in_secs').between(120, 3600))

# a valid taxi trip fare is between $3 and $200 (both inclusive)
df = df.filter(col('fare_amount').between(3, 200))

# a valid taxi trip distance is between 1 mile and 50 miles (both inclusive)
df = df.filter(col('trip_distance').between(1, 50))

# a valid taxi trip toll amount is at most $3
df = df.filter(col('tolls_amount') <= 3)

# keep only the columns needed for this problem
df = df.select('trip_time_in_secs', 'trip_distance', 'fare_amount', 'tolls_amount', 'total_amount')

# output to csv
# df.toPandas().to_csv(f'{OUTPUT_DIR}/cleaned_taxi_data.csv', index=False)

### Data visualization