In [1]:
import requests
import shutil
import os
import time
import random
import concurrent.futures
import findspark
findspark.init()

from pyspark.sql import SparkSession
from pyspark.sql.types import (
    StructType,
    StructField,
    LongType,
    TimestampType,
    DoubleType,
    StringType,
    IntegerType,
    TimestampNTZType
)
from pyspark.sql.functions import col
from pyspark import SparkConf
from py4j.java_gateway import Py4JJavaError

In [1]:
spark_conf = SparkConf()
spark_conf.set("spark.executor.memory", "4g")
spark_conf.set("spark.driver.memory", "2g")
spark_conf.set("spark.network.timeout", "600s")
spark_conf.set("spark.executor.instances", "4")
spark_conf.set("spark.executor.cores", "4")
spark_conf.set("spark.default.parallelism", "8")
spark_conf.set("spark.sql.shuffle.partitions", "8")
spark_conf.set("spark.sql.parquet.enableVectorizedReader", "true")

Note: you may need to restart the kernel to use updated packages.


In [6]:
with SparkSession.builder.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer").config(conf=spark_conf).getOrCreate() as spark:

    def delete_file(file_path):
        try:
            os.remove(file_path)
            print(f"File {file_path} deleted successfuly.")
        except OSError as e:
            print(f"Error deleting file {file_path}: {e}")
    
    def create_dir(dir):
        if not os.path.exists(dir):
            os.makedirs(dir)

    def copy_file(source_path, destination_path):
        shutil.copy2(source_path, destination_path)

    def file_exists(file_path):
        return os.path.exists(file_path)

    def is_directory_empty(directory_path):
        return len(os.listdir(directory_path)) == 0
    
    def download_tripdata(table_name, year, month):
        time.sleep(random.randint(0, 15))
        url = f"https://d37ci6vzurychx.cloudfront.net/trip-data/{table_name}_tripdata_{year}-{month}.parquet"
        create_dir(f"data/landing/{table_name}/{year}/{month}")

        while True:
            try:
                if file_exists(f'data/landing/{table_name}/{year}/{month}/{table_name}_{year}-{month}.parquet'):
                    print("File already downloaded. Skipping download execution")
                    break
                with requests.get(url, stream=True) as r:
                    r.raise_for_status()
                    with open(f'data/landing/{table_name}/{year}/{month}/{table_name}_{year}-{month}.parquet', 'wb') as out_file:
                        shutil.copyfileobj(r.raw, out_file)
                break
            except (Exception, requests.exceptions.RequestException) as e:
                print(f"Request failed with: {e}. Retrying...")
                time.sleep(90)


    def download_and_normalize_data(table_name, year, month):
        
        print(f"Executing: {table_name}-{year}-{month}")
        download_tripdata(table_name, year, month)
        parquet_source_location = f"data/landing/{table_name}/{year}/{month}"
        parquet_load_location = f"data/raw/{table_name}/{year}/{month}"
        create_dir(parquet_load_location)

        try:
            df = spark.read.parquet(parquet_source_location)
            schema_correct = all([field.name in df.columns and field.dataType == df.schema[field.name].dataType for field in get_schema(table_name).fields])
            if not schema_correct:
                print("Falling back to inferring schema.")
                
                schema = get_schema(table_name)
                for field in schema.fields:
                    column_name = field.name
                    data_type = field.dataType
                    df = df.withColumn(column_name, col(column_name).cast(data_type))
    
                df.write.mode('overwrite').parquet(parquet_load_location)
            else:
                file = f"{table_name}_{year}-{month}.parquet"
                if is_directory_empty(parquet_load_location):
                    copy_file(parquet_source_location+f"/{file}", parquet_load_location+f"/{file}")
                else:
                    print(f"Files already loaded for {file}")
        except Py4JJavaError as e:
            file = f"{table_name}_{year}-{month}.parquet"
            delete_file(parquet_source_location+f"/{file}")
            download_and_normalize_data(table_name, year, month)
        except Exception as e:
            print(f"Error processing file: {parquet_source_location} with {e}")


    def get_schema(table_name):
        if table_name == "yellow":
            return StructType([
                StructField("VendorID", LongType(), True),
                StructField("tpep_pickup_datetime", TimestampNTZType(), True),
                StructField("tpep_dropoff_datetime", TimestampNTZType(), True),
                StructField("passenger_count", LongType(), True),
                StructField("trip_distance", DoubleType(), True),
                StructField("RatecodeID", LongType(), True),
                StructField("store_and_fwd_flag", StringType(), True),
                StructField("PULocationID", LongType(), True),
                StructField("DOLocationID", LongType(), True),
                StructField("payment_type", LongType(), True),
                StructField("fare_amount", DoubleType(), True),
                StructField("extra", DoubleType(), True),
                StructField("mta_tax", DoubleType(), True),
                StructField("tip_amount", DoubleType(), True),
                StructField("tolls_amount", DoubleType(), True),
                StructField("improvement_surcharge", DoubleType(), True),
                StructField("total_amount", DoubleType(), True),
                StructField("congestion_surcharge", DoubleType(), True),
                StructField("airport_fee", DoubleType(), True),
            ])
        elif table_name == "green":
            return StructType([
                StructField("VendorID", LongType(), True),
                StructField("lpep_pickup_datetime", TimestampNTZType(), True),
                StructField("lpep_dropoff_datetime", TimestampNTZType(), True),
                StructField("passenger_count", LongType(), True),
                StructField("trip_distance", DoubleType(), True),
                StructField("RatecodeID", LongType(), True),
                StructField("store_and_fwd_flag", StringType(), True),
                StructField("PULocationID", LongType(), True),
                StructField("DOLocationID", LongType(), True),
                StructField("payment_type", LongType(), True),
                StructField("fare_amount", DoubleType(), True),
                StructField("extra", DoubleType(), True),
                StructField("mta_tax", DoubleType(), True),
                StructField("tip_amount", DoubleType(), True),
                StructField("tolls_amount", DoubleType(), True),
                StructField("ehail_fee", IntegerType(), True),
                StructField("improvement_surcharge", DoubleType(), True),
                StructField("total_amount", DoubleType(), True),
                StructField("trip_type", DoubleType(), True),
                StructField("congestion_surcharge", DoubleType(), True)
            ])
        else:
            raise ValueError(f"Table name '{table_name}' is not supported.")

    demo_download_list = [
        ["yellow", "2018", [str(x).zfill(2) for x in range(1, 13)]],
        ["yellow", "2019", [str(x).zfill(2) for x in range(1, 13)]],
        ["yellow", "2020", [str(x).zfill(2) for x in range(1, 13)]],
        ["yellow", "2021", [str(x).zfill(2) for x in range(1, 13)]],
        ["yellow", "2022", [str(x).zfill(2) for x in range(1, 13)]],
        ["green", "2018", [str(x).zfill(2) for x in range(1, 13)]],
        ["green", "2019", [str(x).zfill(2) for x in range(1, 13)]],
        ["green", "2020", [str(x).zfill(2) for x in range(1, 13)]],
        ["green", "2021", [str(x).zfill(2) for x in range(1, 13)]],
        ["green", "2022", [str(x).zfill(2) for x in range(1, 13)]]
    ]

    def parallel_download_and_normalize_data(demo_download):
        table_name, year, months = demo_download
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = []
            for month in months:
                future = executor.submit(download_and_normalize_data, table_name, year, month)
                futures.append(future)
            concurrent.futures.wait(futures)

    for demo_download in demo_download_list:
        parallel_download_and_normalize_data(demo_download)

Executing: green-2018-01
Executing: green-2018-02
Executing: green-2018-03
Executing: green-2018-04
Executing: green-2018-05
Executing: green-2018-06
Executing: green-2018-07
Executing: green-2018-08
Executing: green-2018-09
File already downloaded. Skipping download execution
Falling back to inferring schema.
File already downloaded. Skipping download execution
Falling back to inferring schema.
File already downloaded. Skipping download execution
File already downloaded. Skipping download execution
Falling back to inferring schema.
Falling back to inferring schema.
File already downloaded. Skipping download execution
Falling back to inferring schema.
File already downloaded. Skipping download execution
File already downloaded. Skipping download execution
File already downloaded. Skipping download execution
Falling back to inferring schema.
Falling back to inferring schema.
Falling back to inferring schema.
Executing: green-2018-10
Executing: green-2018-11
Executing: green-2018-12
File

In [4]:
import requests
import shutil
import os
import findspark
import time
import random
findspark.init()

from pyspark.sql import SparkSession
from pyspark.sql.types import (
    StructType,
    StructField,
    LongType,
    TimestampType,
    DoubleType,
    StringType,
    IntegerType,
    TimestampNTZType
)
from pyspark.sql.functions import col
from pyspark import SparkConf
from py4j.java_gateway import Py4JJavaError

spark = SparkSession.builder.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer").getOrCreate()
spark.conf.set("spark.sql.parquet.enableVectorizedReader", "true")
df = spark.read.parquet("data/landing/green/2022/01")
df.printSchema()


def delete_file(file_path):
    try:
        os.remove(file_path)
        print(f"File {file_path} deleted successfuly.")
    except OSError as e:
        print(f"Error deleting file {file_path}: {e}")

def create_dir(dir):
    if not os.path.exists(dir):
        os.makedirs(dir)

def copy_file(source_path, destination_path):
    shutil.copy2(source_path, destination_path)

def file_exists(file_path):
    return os.path.exists(file_path)


def download_tripdata(table_name, year, month):
    time.sleep(random.randint(0, 15))
    url = f"https://d37ci6vzurychx.cloudfront.net/trip-data/{table_name}_tripdata_{year}-{month}.parquet"
    create_dir(f"data/landing/{table_name}/{year}/{month}")

    while True:
        try:
            if file_exists(f'data/landing/{table_name}/{year}/{month}/{table_name}_{year}-{month}.parquet'):
                print("File already downloaded. Skipping download execution")
                break
            with requests.get(url, stream=True) as r:
                r.raise_for_status()
                with open(f'data/landing/{table_name}/{year}/{month}/{table_name}_{year}-{month}.parquet', 'wb') as out_file:
                    shutil.copyfileobj(r.raw, out_file)
            break
        except (Exception, requests.exceptions.RequestException) as e:
            print(f"Request failed with: {e}. Retrying...")
            time.sleep(90)

def get_schema(table_name):
    if table_name == "yellow":
        return StructType([
            StructField("VendorID", LongType(), True),
            StructField("tpep_pickup_datetime", TimestampNTZType(), True),
            StructField("tpep_dropoff_datetime", TimestampNTZType(), True),
            StructField("passenger_count", LongType(), True),
            StructField("trip_distance", DoubleType(), True),
            StructField("RatecodeID", LongType(), True),
            StructField("store_and_fwd_flag", StringType(), True),
            StructField("PULocationID", LongType(), True),
            StructField("DOLocationID", LongType(), True),
            StructField("payment_type", LongType(), True),
            StructField("fare_amount", DoubleType(), True),
            StructField("extra", DoubleType(), True),
            StructField("mta_tax", DoubleType(), True),
            StructField("tip_amount", DoubleType(), True),
            StructField("tolls_amount", DoubleType(), True),
            StructField("improvement_surcharge", DoubleType(), True),
            StructField("total_amount", DoubleType(), True),
            StructField("congestion_surcharge", DoubleType(), True),
            StructField("airport_fee", DoubleType(), True),
        ])
    elif table_name == "green":
        return StructType([
            StructField("VendorID", LongType(), True),
            StructField("tpep_pickup_datetime", TimestampNTZType(), True),
            StructField("tpep_dropoff_datetime", TimestampNTZType(), True),
            StructField("passenger_count", LongType(), True),
            StructField("trip_distance", DoubleType(), True),
            StructField("RatecodeID", LongType(), True),
            StructField("store_and_fwd_flag", StringType(), True),
            StructField("PULocationID", LongType(), True),
            StructField("DOLocationID", LongType(), True),
            StructField("payment_type", LongType(), True),
            StructField("fare_amount", DoubleType(), True),
            StructField("extra", DoubleType(), True),
            StructField("mta_tax", DoubleType(), True),
            StructField("tip_amount", DoubleType(), True),
            StructField("tolls_amount", DoubleType(), True),
            StructField("ehail_fee", IntegerType(), True),
            StructField("improvement_surcharge", DoubleType(), True),
            StructField("total_amount", DoubleType(), True),
            StructField("trip_type", DoubleType(), True),
            StructField("congestion_surcharge", DoubleType(), True)
        ])
    else:
        raise ValueError(f"Table name '{table_name}' is not supported.")

def download_and_normalize_data(table_name, year, month):
    
    print(f"Executing: {table_name}-{year}-{month}")
    download_tripdata(table_name, year, month)
    parquet_source_location = f"data/landing/{table_name}/{year}/{month}"
    print(parquet_source_location)
    parquet_load_location = f"data/raw/{table_name}/{year}/{month}"
    create_dir(parquet_load_location)

    try:
        df = spark.read.parquet(parquet_source_location)
        schema_correct = all([field.name in df.columns and field.dataType == df.schema[field.name].dataType for field in get_schema(table_name).fields])
        print(schema_correct)
        df.printSchema()
        return
        if not schema_correct:
            print("Falling back to inferring schema.")
            
            schema = get_schema(table_name)
            for field in schema.fields:
                column_name = field.name
                data_type = field.dataType
                df = df.withColumn(column_name, col(column_name).cast(data_type))

            df.write.mode('overwrite').parquet(parquet_load_location)
        else:
            file = f"{table_name}_{year}-{month}.parquet"
            copy_file(parquet_source_location+f"/{file}", parquet_load_location+f"/{file}")
    except Py4JJavaError as e:
        print(e)
        file = f"{table_name}_{year}-{month}.parquet"
        delete_file(parquet_source_location+f"/{file}")
        download_and_normalize_data(table_name, year, month)
    except Exception as e:
        print(f"Error processing file: {parquet_source_location} with {e}")


download_and_normalize_data("green", "2022", "01")

root
 |-- VendorID: long (nullable = true)
 |-- lpep_pickup_datetime: timestamp_ntz (nullable = true)
 |-- lpep_dropoff_datetime: timestamp_ntz (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- RatecodeID: double (nullable = true)
 |-- PULocationID: long (nullable = true)
 |-- DOLocationID: long (nullable = true)
 |-- passenger_count: double (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- ehail_fee: integer (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- payment_type: double (nullable = true)
 |-- trip_type: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)

Executing: green-2022-01
File already downloaded. Skipping download execution
data/landi