https://github.com/DataTalksClub/data-engineering-zoomcamp/blob/main/05-batch/code/04_pyspark.ipynb

In [1]:
import pyspark
from pyspark.sql import SparkSession

In [2]:
# Create (or reuse) a Spark session running locally on all available cores.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName('test')
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/03/05 07:50:52 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [4]:
!curl -LO https://github.com/DataTalksClub/nyc-tlc-data/releases/download/fhvhv/fhvhv_tripdata_2021-01.csv.gz --output-dir ../data/

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
100  123M  100  123M    0     0  5796k      0  0:00:21  0:00:21 --:--:-- 9108k


In [6]:
!gzip -dk ../data/fhvhv_tripdata_2021-01.csv.gz

../data/fhvhv_tripdata_2021-01.csv already exists -- do you wish to overwrite (y or n)? ^C


In [3]:
# Count the lines in the decompressed CSV (the header row is included in the count).
!wc -l ../data/fhvhv_tripdata_2021-01.csv

 11908469 ../data/fhvhv_tripdata_2021-01.csv


In [10]:
# Load the full CSV, taking column names from the header row and letting Spark
# infer each column's type.
pyspark_df = (
    spark.read
    .option('header', 'true')
    .csv('../data/fhvhv_tripdata_2021-01.csv', inferSchema=True)
)

                                                                                

In [12]:
# Show the column types Spark inferred from the CSV.
pyspark_df.schema

StructType([StructField('hvfhs_license_num', StringType(), True), StructField('dispatching_base_num', StringType(), True), StructField('pickup_datetime', TimestampType(), True), StructField('dropoff_datetime', TimestampType(), True), StructField('PULocationID', IntegerType(), True), StructField('DOLocationID', IntegerType(), True), StructField('SR_Flag', IntegerType(), True)])

In [11]:
# Preview the first rows of the inferred-schema DataFrame.
pyspark_df.show()

+-----------------+--------------------+-------------------+-------------------+------------+------------+-------+
|hvfhs_license_num|dispatching_base_num|    pickup_datetime|   dropoff_datetime|PULocationID|DOLocationID|SR_Flag|
+-----------------+--------------------+-------------------+-------------------+------------+------------+-------+
|           HV0003|              B02682|2021-01-01 00:33:44|2021-01-01 00:49:07|         230|         166|   NULL|
|           HV0003|              B02682|2021-01-01 00:55:19|2021-01-01 01:18:21|         152|         167|   NULL|
|           HV0003|              B02764|2021-01-01 00:23:56|2021-01-01 00:38:05|         233|         142|   NULL|
|           HV0003|              B02764|2021-01-01 00:42:51|2021-01-01 00:45:50|         142|         143|   NULL|
|           HV0003|              B02764|2021-01-01 00:48:14|2021-01-01 01:08:42|         143|          78|   NULL|
|           HV0005|              B02510|2021-01-01 00:06:59|2021-01-01 00:43:01|

In [11]:
# Save the header line plus the first 1000 data rows as a small sample file.
!head -n 1001 ../data/fhvhv_tripdata_2021-01.csv > ../data/fhvhv_tripdata_2021-01_head.csv

In [12]:
# Peek at the first 10 lines of the sample file.
!head -n 10 ../data/fhvhv_tripdata_2021-01_head.csv

hvfhs_license_num,dispatching_base_num,pickup_datetime,dropoff_datetime,PULocationID,DOLocationID,SR_Flag
HV0003,B02682,2021-01-01 00:33:44,2021-01-01 00:49:07,230,166,
HV0003,B02682,2021-01-01 00:55:19,2021-01-01 01:18:21,152,167,
HV0003,B02764,2021-01-01 00:23:56,2021-01-01 00:38:05,233,142,
HV0003,B02764,2021-01-01 00:42:51,2021-01-01 00:45:50,142,143,
HV0003,B02764,2021-01-01 00:48:14,2021-01-01 01:08:42,143,78,
HV0005,B02510,2021-01-01 00:06:59,2021-01-01 00:43:01,88,42,
HV0005,B02510,2021-01-01 00:50:00,2021-01-01 01:04:57,42,151,
HV0003,B02764,2021-01-01 00:14:30,2021-01-01 00:50:27,71,226,
HV0003,B02875,2021-01-01 00:22:54,2021-01-01 00:30:20,112,255,


In [13]:
# Confirm the sample file's line count.
!wc -l ../data/fhvhv_tripdata_2021-01_head.csv

    1001 ../data/fhvhv_tripdata_2021-01_head.csv


In [7]:
import pandas as pd

In [8]:
# Read the small sample file into pandas; no dtypes are specified, so pandas infers them.
pandas_df = pd.read_csv('../data/fhvhv_tripdata_2021-01_head.csv')

In [9]:
# Inspect the dtype pandas chose for each column.
pandas_df.dtypes

hvfhs_license_num        object
dispatching_base_num     object
pickup_datetime          object
dropoff_datetime         object
PULocationID              int64
DOLocationID              int64
SR_Flag                 float64
dtype: object

In [20]:
# Convert the pandas sample to a Spark DataFrame and inspect the schema Spark derives from it.
spark.createDataFrame(pandas_df).schema

StructType([StructField('hvfhs_license_num', StringType(), True), StructField('dispatching_base_num', StringType(), True), StructField('pickup_datetime', StringType(), True), StructField('dropoff_datetime', StringType(), True), StructField('PULocationID', LongType(), True), StructField('DOLocationID', LongType(), True), StructField('SR_Flag', DoubleType(), True)])

In [13]:
from pyspark.sql import types

In [14]:
# Explicit schema for the trip CSV: one nullable field per column, in file order.
schema = types.StructType([
    types.StructField(column_name, column_type, True)
    for column_name, column_type in (
        ('hvfhs_license_num', types.StringType()),
        ('dispatching_base_num', types.StringType()),
        ('pickup_datetime', types.TimestampType()),
        ('dropoff_datetime', types.TimestampType()),
        ('PULocationID', types.IntegerType()),
        ('DOLocationID', types.IntegerType()),
        ('SR_Flag', types.StringType()),
    )
])

In [15]:
# Read the full CSV again, applying the explicit schema instead of inferring types.
pyspark_df_cust_schema = (
    spark.read
    .option('header', 'true')
    .schema(schema)
    .csv('../data/fhvhv_tripdata_2021-01.csv')
)

In [17]:
# Check that the DataFrame carries the explicit schema.
pyspark_df_cust_schema.schema

StructType([StructField('hvfhs_license_num', StringType(), True), StructField('dispatching_base_num', StringType(), True), StructField('pickup_datetime', TimestampType(), True), StructField('dropoff_datetime', TimestampType(), True), StructField('PULocationID', IntegerType(), True), StructField('DOLocationID', IntegerType(), True), StructField('SR_Flag', StringType(), True)])

In [31]:
# Split the data into 24 partitions before writing. Use the explicit-schema
# DataFrame: the Parquet files read back later have SR_Flag as a string, which
# matches that schema and not the inferred one (where SR_Flag is an integer).
pyspark_df_partitioned = pyspark_df_cust_schema.repartition(24)

In [36]:
# Write the repartitioned data as Parquet. mode='overwrite' replaces output left by
# a previous run; the default save mode raises an error when the directory exists,
# which breaks Restart & Run All.
pyspark_df_partitioned.write.parquet('../data/fhvhv_tripdata/2021/01', mode='overwrite')

24/03/04 19:58:36 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
24/03/04 19:58:38 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
24/03/04 19:58:39 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
24/03/04 19:58:39 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
                                                                                

In [34]:
# Write the un-repartitioned Spark DataFrame as Parquet. pandas DataFrames have no
# `.write` accessor, so this must start from a Spark DataFrame.
# NOTE(review): assumes the explicit-schema DataFrame was the intended source — confirm.
pyspark_df_cust_schema.write.parquet('../data/fhvhv_tripdata_x/2021/01', mode='overwrite')

24/03/04 19:57:27 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
                                                                                

In [18]:
# Load the Parquet directory written earlier back into a Spark DataFrame.
pyspark_df_from_parquet = spark.read.parquet('../data/fhvhv_tripdata/2021/01')

                                                                                

In [26]:
# Print the schema stored in the Parquet files as a tree.
pyspark_df_from_parquet.printSchema()

root
 |-- hvfhs_license_num: string (nullable = true)
 |-- dispatching_base_num: string (nullable = true)
 |-- pickup_datetime: timestamp (nullable = true)
 |-- dropoff_datetime: timestamp (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- SR_Flag: string (nullable = true)



In [30]:
# Trips for license HV0003, keeping only the time and location columns.
# The filter comes before the projection so hvfhs_license_num is still present
# when the condition is applied.
(
    pyspark_df_from_parquet
    .filter(pyspark_df_from_parquet.hvfhs_license_num == 'HV0003')
    .select('pickup_datetime', 'dropoff_datetime', 'PULocationID', 'DOLocationID')
    .show()
)

+-------------------+-------------------+------------+------------+
|    pickup_datetime|   dropoff_datetime|PULocationID|DOLocationID|
+-------------------+-------------------+------------+------------+
|2021-01-01 16:17:16|2021-01-01 16:35:20|         153|         167|
|2021-01-04 15:44:57|2021-01-04 15:56:21|         127|         244|
|2021-01-03 14:05:52|2021-01-03 14:20:29|         183|          20|
|2021-01-01 09:27:50|2021-01-01 09:46:50|         162|         106|
|2021-01-04 10:53:41|2021-01-04 11:53:07|          18|          18|
|2021-01-04 22:29:41|2021-01-04 22:51:07|          76|         215|
|2021-01-02 14:52:09|2021-01-02 15:12:01|         220|          32|
|2021-01-01 07:31:49|2021-01-01 07:46:16|         167|          94|
|2021-01-02 22:35:51|2021-01-02 23:07:17|          71|         130|
|2021-01-01 15:38:12|2021-01-01 15:58:56|         222|          61|
|2021-01-05 02:08:44|2021-01-05 02:26:05|         181|          17|
|2021-01-01 03:56:09|2021-01-01 04:06:33|       

In [31]:
from pyspark.sql import functions as F

In [45]:
# Derive a date-only column from the pickup timestamp and preview the result.
(
    pyspark_df_from_parquet
    .withColumn('pickup_date', F.to_date('pickup_datetime'))
    .show()
)

+-----------------+--------------------+-------------------+-------------------+------------+------------+-------+-----------+
|hvfhs_license_num|dispatching_base_num|    pickup_datetime|   dropoff_datetime|PULocationID|DOLocationID|SR_Flag|pickup_date|
+-----------------+--------------------+-------------------+-------------------+------------+------------+-------+-----------+
|           HV0003|              B02884|2021-01-01 16:17:16|2021-01-01 16:35:20|         153|         167|   NULL| 2021-01-01|
|           HV0003|              B02882|2021-01-04 15:44:57|2021-01-04 15:56:21|         127|         244|   NULL| 2021-01-04|
|           HV0005|              B02510|2021-01-04 14:29:31|2021-01-04 14:54:56|         138|          49|   NULL| 2021-01-04|
|           HV0003|              B02864|2021-01-03 14:05:52|2021-01-03 14:20:29|         183|          20|   NULL| 2021-01-03|
|           HV0003|              B02867|2021-01-01 09:27:50|2021-01-01 09:46:50|         162|         106|   NU

In [41]:
# Add date-only versions of both timestamps in a single withColumns call.
(
    pyspark_df_from_parquet
    .withColumns({
        'pickup_date': F.to_date('pickup_datetime'),
        'dropoff_date': F.to_date('dropoff_datetime'),
    })
    .show()
)

+-----------------+--------------------+-------------------+-------------------+------------+------------+-------+-----------+------------+
|hvfhs_license_num|dispatching_base_num|    pickup_datetime|   dropoff_datetime|PULocationID|DOLocationID|SR_Flag|pickup_date|dropoff_date|
+-----------------+--------------------+-------------------+-------------------+------------+------------+-------+-----------+------------+
|           HV0003|              B02884|2021-01-01 16:17:16|2021-01-01 16:35:20|         153|         167|   NULL| 2021-01-01|  2021-01-01|
|           HV0003|              B02882|2021-01-04 15:44:57|2021-01-04 15:56:21|         127|         244|   NULL| 2021-01-04|  2021-01-04|
|           HV0005|              B02510|2021-01-04 14:29:31|2021-01-04 14:54:56|         138|          49|   NULL| 2021-01-04|  2021-01-04|
|           HV0003|              B02864|2021-01-03 14:05:52|2021-01-03 14:20:29|         183|          20|   NULL| 2021-01-03|  2021-01-03|
|           HV0003| 

In [42]:
def crazy_stuff(base_num):
    """Tag a dispatching base number by whether its numeric part is divisible by 7.

    Dummy user-defined function used to demonstrate Spark UDFs.

    Args:
        base_num: Base identifier whose first character is a prefix and whose
            remainder is a decimal number, e.g. ``'B02875'``. May be ``None``.

    Returns:
        ``'s/<hex>'`` when the number is divisible by 7, otherwise ``'e/<hex>'``,
        where ``<hex>`` is the number in lowercase hexadecimal, zero-padded to at
        least three digits. ``None`` when ``base_num`` is ``None``.

    Raises:
        ValueError: If the text after the first character is not an integer.
    """
    # The source column is nullable; pass nulls through instead of failing on None[1:],
    # which would abort the whole Spark job when used as a UDF.
    if base_num is None:
        return None

    num = int(base_num[1:])
    prefix = 's' if num % 7 == 0 else 'e'
    return f'{prefix}/{num:03x}'

In [43]:
# Try the plain Python function on a sample base number before wrapping it as a UDF.
crazy_stuff('B02875')

'e/b3b'

In [44]:
# Wrap the Python function as a Spark UDF that returns a string column.
crazy_stuff_udf = F.udf(
    f=crazy_stuff,
    returnType=types.StringType(),
)

In [None]:
# Add both date columns plus a base_id column computed by the UDF, then preview.
(
    pyspark_df_from_parquet
    .withColumns({
        'pickup_date': F.to_date('pickup_datetime'),
        'dropoff_date': F.to_date('dropoff_datetime'),
        'base_id': crazy_stuff_udf('dispatching_base_num'),
    })
    .show()
)