In [1]:
import pyspark

In [2]:
%load_ext jupyter_black
%load_ext autoreload
%autoreload 2

In [3]:
pyspark.__version__

'3.4.1'

https://spark.apache.org/docs/3.4.1/quick-start.html#self-contained-applications

In [4]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType
from pyspark.sql.window import Window
from pyspark.sql import functions as F
from datetime import datetime

In [5]:
import os

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"

In [6]:
spark = (
    SparkSession.builder.appName("SimpleApp")
    .config("spark.executor.memory", "4g")
    .config("spark.driver.memory", "4g")
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .master("local[4]")
    .getOrCreate()
)

In [7]:
471859200 / 268435456.0

1.7578125

When an app is created, Spark UI is available (default port 4040).

http://localhost:32769/jobs/

In [8]:
spark.conf.get("spark.executor.memory")

'4g'

In [9]:
spark.sparkContext._conf.getAll()

[('spark.app.submitTime', '1697364616290'),
 ('spark.app.startTime', '1697364616451'),
 ('spark.driver.port', '37787'),
 ('spark.master', 'local[4]'),
 ('spark.app.name', 'SimpleApp'),
 ('spark.driver.memory', '4g'),
 ('spark.executor.memory', '4g'),
 ('spark.executor.id', 'driver'),
 ('spark.app.id', 'local-1697364617116'),
 ('spark.driver.extraJavaOptions',
  '-Djava.net.preferIPv6Addresses=false -XX:+IgnoreUnrecognizedVMOptions --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.acti

In [10]:
spark.conf.get("spark.master")

'local[4]'

In [11]:
! du -h data/nyc_taxi/fhvhv_*

289M	data/nyc_taxi/fhvhv_tripdata_2021-02.parquet
352M	data/nyc_taxi/fhvhv_tripdata_2021-03.parquet
352M	data/nyc_taxi/fhvhv_tripdata_2021-04.parquet
370M	data/nyc_taxi/fhvhv_tripdata_2021-05.parquet
376M	data/nyc_taxi/fhvhv_tripdata_2021-06.parquet
378M	data/nyc_taxi/fhvhv_tripdata_2021-07.parquet
365M	data/nyc_taxi/fhvhv_tripdata_2021-08.parquet
376M	data/nyc_taxi/fhvhv_tripdata_2021-09.parquet
411M	data/nyc_taxi/fhvhv_tripdata_2021-10.parquet
393M	data/nyc_taxi/fhvhv_tripdata_2021-11.parquet
392M	data/nyc_taxi/fhvhv_tripdata_2021-12.parquet
358M	data/nyc_taxi/fhvhv_tripdata_2022-01.parquet
389M	data/nyc_taxi/fhvhv_tripdata_2022-02.parquet
450M	data/nyc_taxi/fhvhv_tripdata_2022-03.parquet
435M	data/nyc_taxi/fhvhv_tripdata_2022-04.parquet
447M	data/nyc_taxi/fhvhv_tripdata_2022-05.parquet
437M	data/nyc_taxi/fhvhv_tripdata_2022-06.parquet
424M	data/nyc_taxi/fhvhv_tripdata_2022-07.parquet
417M	data/nyc_taxi/fhvhv_tripdata_2022-08.parquet
437M	data/nyc_taxi/fhvhv_tripdata_2022-09.parquet


In [12]:
sdf = spark.read.parquet("data/nyc_taxi/fhvhv_tripdata_2021-0*")

In [13]:
sdf.printSchema()

root
 |-- hvfhs_license_num: string (nullable = true)
 |-- dispatching_base_num: string (nullable = true)
 |-- originating_base_num: string (nullable = true)
 |-- request_datetime: timestamp_ntz (nullable = true)
 |-- on_scene_datetime: timestamp_ntz (nullable = true)
 |-- pickup_datetime: timestamp_ntz (nullable = true)
 |-- dropoff_datetime: timestamp_ntz (nullable = true)
 |-- PULocationID: long (nullable = true)
 |-- DOLocationID: long (nullable = true)
 |-- trip_miles: double (nullable = true)
 |-- trip_time: long (nullable = true)
 |-- base_passenger_fare: double (nullable = true)
 |-- tolls: double (nullable = true)
 |-- bcf: double (nullable = true)
 |-- sales_tax: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- airport_fee: double (nullable = true)
 |-- tips: double (nullable = true)
 |-- driver_pay: double (nullable = true)
 |-- shared_request_flag: string (nullable = true)
 |-- shared_match_flag: string (nullable = true)
 |-- access_a_ride_f

## Windows functions over

In [14]:
df = spark.createDataFrame(
    [
        (1, 2, "a"),
        (2, 2, "a"),
        (1, 3, "b"),
        (3, 2, "a"),
        (4, 3, "a"),
        (5, 3, "a"),
        (2, 3, "b"),
    ],
    ["time", "value", "class"],
)

windowval = (
    Window.partitionBy("class").orderBy("time").rangeBetween(-1, Window.currentRow)
)
df_w_cumsum = df.withColumn("cum_sum", F.sum("value").over(windowval))
df_w_cumsum.show()

+----+-----+-----+-------+
|time|value|class|cum_sum|
+----+-----+-----+-------+
|   1|    2|    a|      2|
|   2|    2|    a|      4|
|   3|    2|    a|      4|
|   4|    3|    a|      5|
|   5|    3|    a|      6|
|   1|    3|    b|      3|
|   2|    3|    b|      6|
+----+-----+-----+-------+



## Pnads-on-Spark

In [15]:
import pyspark.pandas as ps
import pandas as pd

In [16]:
df = pd.DataFrame({"a": [1, 2, 3, 4]})
sum(df.a)

10

In [17]:
ps_test_df = ps.DataFrame({"id": range(10)})

In [18]:
ps_test_df["id"].max()

9

In [19]:
psdf = sdf.pandas_api()

In [20]:
psdf["hvfhs_license_num"].count()

114046694

In [21]:
psdf["dispatching_base_num"].head()

0    B02764
1    B02764
2    B02510
3    B02510
4    B02872
Name: dispatching_base_num, dtype: object

In [22]:
%%time
res = (
    psdf.groupby("dispatching_base_num")["dispatching_base_num"]
    .count()
    .sort_values(ascending=False)
)

CPU times: user 25.2 ms, sys: 3.39 ms, total: 28.5 ms
Wall time: 108 ms


In [23]:
res.head()

dispatching_base_num
B02510    31052693
B02764    10320707
B02872     8742376
B02875     7078022
B02765     5325306
Name: dispatching_base_num, dtype: int64

In [30]:
psdf_top = (
    psdf[["hvfhs_license_num", "trip_miles", "tips"]]
    .groupby("hvfhs_license_num")
    .agg(["min", "max", "sum"])
    .sort_values([("tips", "sum")], ascending=False)
)

In [31]:
psdf_top.spark.explain()

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [__index_level_0__#965, (trip_miles, min)#966, (trip_miles, max)#967, (trip_miles, sum)#968, (tips, min)#969, (tips, max)#970, (tips, sum)#971]
   +- Sort [(tips, sum)#971 DESC NULLS LAST, __natural_order__#1023L ASC NULLS FIRST], true, 0
      +- Exchange rangepartitioning((tips, sum)#971 DESC NULLS LAST, __natural_order__#1023L ASC NULLS FIRST, 200), ENSURE_REQUIREMENTS, [plan_id=448]
         +- HashAggregate(keys=[__index_level_0__#965], functions=[min(trip_miles#9), max(trip_miles#9), sum(trip_miles#9), min(tips#17), max(tips#17), sum(tips#17)])
            +- Exchange hashpartitioning(__index_level_0__#965, 200), ENSURE_REQUIREMENTS, [plan_id=445]
               +- HashAggregate(keys=[__index_level_0__#965], functions=[partial_min(trip_miles#9), partial_max(trip_miles#9), partial_sum(trip_miles#9), partial_min(tips#17), partial_max(tips#17), partial_sum(tips#17)])
                  +- Project [hvfhs_license_num#0 

In [32]:
%%time
psdf_top.head(10)

CPU times: user 12.1 ms, sys: 946 µs, total: 13 ms
Wall time: 32.3 ms


Unnamed: 0_level_0,trip_miles,trip_miles,trip_miles,tips,tips,tips
Unnamed: 0_level_1,min,max,sum,min,max,sum
hvfhs_license_num,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
HV0003,0.0,568.64,390435500.0,0.0,1000.0,59872190.0
HV0005,0.0,392.852,155794400.0,0.0,200.0,30142900.0
HV0004,0.0,106.19,3972909.0,0.0,90.0,392198.8


In [33]:
%%time
psdf.count()

CPU times: user 195 ms, sys: 65.7 ms, total: 261 ms
Wall time: 29.6 s


hvfhs_license_num       114046694
dispatching_base_num    114046694
originating_base_num     81745554
request_datetime        114046693
on_scene_datetime        81757624
pickup_datetime         114046694
dropoff_datetime        114046694
PULocationID            114046694
DOLocationID            114046694
trip_miles              114046694
trip_time               114046694
base_passenger_fare     114046694
tolls                   114046694
bcf                     114046694
sales_tax               114046694
congestion_surcharge    114046694
airport_fee              89452905
tips                    114046694
driver_pay              114046694
shared_request_flag     114046694
shared_match_flag       114046694
access_a_ride_flag      114046694
wav_request_flag        114046694
wav_match_flag          114046694
dtype: int64

In [34]:
%%time
psdf["trip_miles"].max()

CPU times: user 10.3 ms, sys: 8.58 ms, total: 18.8 ms
Wall time: 1.3 s


568.64

In [35]:
ps.set_option("display.max_rows", 101)

In [36]:
from pyspark.sql.functions import date_format

In [37]:
req_dt = (
    sdf.select(["request_datetime", "trip_miles"])
    .withColumn("MM", date_format("request_datetime", "yyyy-MM"))
    .show()
)

+-------------------+----------+-------+
|   request_datetime|trip_miles|     MM|
+-------------------+----------+-------+
|2021-01-31 23:59:00|      2.06|2021-01|
|2021-02-01 00:13:35|      3.15|2021-02|
|2021-02-01 00:12:55|     1.776|2021-02|
|2021-02-01 00:36:01|    13.599|2021-02|
|2021-01-31 23:57:50|      2.62|2021-01|
|2021-02-01 00:11:48|      6.89|2021-02|
|2021-02-01 00:39:45|      4.26|2021-02|
|2021-01-31 23:55:59|      2.95|2021-01|
|2021-02-01 00:27:54|      3.41|2021-02|
|2021-01-31 23:56:04|    15.998|2021-01|
|2021-02-01 00:46:24|     2.354|2021-02|
|2021-01-31 23:59:26|     9.643|2021-01|
|2021-02-01 00:31:01|    10.231|2021-02|
|2021-01-31 23:58:59|      2.69|2021-01|
|2021-02-01 00:19:07|       0.6|2021-02|
|2021-02-01 00:25:36|      8.37|2021-02|
|2021-02-01 00:02:23|      1.88|2021-02|
|2021-02-01 00:31:11|      8.04|2021-02|
|2021-01-31 23:45:24|      4.29|2021-01|
|2021-02-01 00:08:40|      1.53|2021-02|
+-------------------+----------+-------+
only showing top

In [38]:
sdf.select

<bound method DataFrame.select of DataFrame[hvfhs_license_num: string, dispatching_base_num: string, originating_base_num: string, request_datetime: timestamp_ntz, on_scene_datetime: timestamp_ntz, pickup_datetime: timestamp_ntz, dropoff_datetime: timestamp_ntz, PULocationID: bigint, DOLocationID: bigint, trip_miles: double, trip_time: bigint, base_passenger_fare: double, tolls: double, bcf: double, sales_tax: double, congestion_surcharge: double, airport_fee: double, tips: double, driver_pay: double, shared_request_flag: string, shared_match_flag: string, access_a_ride_flag: string, wav_request_flag: string, wav_match_flag: string]>

In [39]:
from pyspark.sql.types import StructType, StructField, DoubleType

## applyInPandas (Apache Arrow in PySpark)
https://spark.apache.org/docs/latest/api/python/user_guide/sql/arrow_pandas.html

### Pandas UDFs (a.k.a. Vectorized UDFs)

In [40]:
%%time
schema = StructType(sdf.schema.fields + [StructField("v", DoubleType())])


def vector_normalize(values):
    v1 = values.driver_pay
    values["v"] = (v1 - v1.mean()) / v1.std()
    return values


group_columns = ["hvfhs_license_num", "driver_pay"]
df_pandas_norm = sdf.groupby(*group_columns).applyInPandas(
    vector_normalize, schema=schema
)
df_pandas_norm.show()

+-----------------+--------------------+--------------------+-------------------+-------------------+-------------------+-------------------+------------+------------+----------+---------+-------------------+-----+---+---------+--------------------+-----------+-----+----------+-------------------+-----------------+------------------+----------------+--------------+----+
|hvfhs_license_num|dispatching_base_num|originating_base_num|   request_datetime|  on_scene_datetime|    pickup_datetime|   dropoff_datetime|PULocationID|DOLocationID|trip_miles|trip_time|base_passenger_fare|tolls|bcf|sales_tax|congestion_surcharge|airport_fee| tips|driver_pay|shared_request_flag|shared_match_flag|access_a_ride_flag|wav_request_flag|wav_match_flag|   v|
+-----------------+--------------------+--------------------+-------------------+-------------------+-------------------+-------------------+------------+------------+----------+---------+-------------------+-----+---+---------+--------------------+-----

In [44]:
sdf.limit(10).toPandas()

Unnamed: 0,hvfhs_license_num,dispatching_base_num,originating_base_num,request_datetime,on_scene_datetime,pickup_datetime,dropoff_datetime,PULocationID,DOLocationID,trip_miles,...,sales_tax,congestion_surcharge,airport_fee,tips,driver_pay,shared_request_flag,shared_match_flag,access_a_ride_flag,wav_request_flag,wav_match_flag
0,HV0003,B02764,B02764,2021-01-31 23:59:00,2021-02-01 00:10:19,2021-02-01 00:10:40,2021-02-01 00:21:09,35,39,2.06,...,1.52,0.0,,0.0,9.79,N,N,,N,N
1,HV0003,B02764,B02764,2021-02-01 00:13:35,2021-02-01 00:25:23,2021-02-01 00:27:23,2021-02-01 00:44:01,39,35,3.15,...,2.85,0.0,,0.0,24.01,N,N,,N,N
2,HV0005,B02510,,2021-02-01 00:12:55,NaT,2021-02-01 00:28:38,2021-02-01 00:38:27,39,91,1.776,...,1.12,0.0,,0.0,6.91,N,N,N,N,N
3,HV0005,B02510,,2021-02-01 00:36:01,NaT,2021-02-01 00:43:37,2021-02-01 01:23:20,91,228,13.599,...,2.91,0.0,,7.0,35.05,N,N,N,N,N
4,HV0003,B02872,B02872,2021-01-31 23:57:50,2021-02-01 00:08:25,2021-02-01 00:08:42,2021-02-01 00:17:57,126,250,2.62,...,1.38,0.0,,0.0,8.53,N,N,,N,N
5,HV0003,B02872,B02872,2021-02-01 00:11:48,2021-02-01 00:24:25,2021-02-01 00:26:02,2021-02-01 00:42:51,208,243,6.89,...,1.77,0.0,,0.0,16.05,N,N,,N,N
6,HV0003,B02872,B02872,2021-02-01 00:39:45,2021-02-01 00:44:57,2021-02-01 00:45:50,2021-02-01 01:02:50,243,220,4.26,...,3.76,0.0,,0.0,25.42,N,N,,N,N
7,HV0003,B02764,B02764,2021-01-31 23:55:59,2021-02-01 00:04:42,2021-02-01 00:06:42,2021-02-01 00:31:50,49,37,2.95,...,2.4,0.0,,0.0,22.29,N,N,,N,N
8,HV0003,B02764,B02764,2021-02-01 00:27:54,2021-02-01 00:33:12,2021-02-01 00:34:34,2021-02-01 00:58:13,37,76,3.41,...,2.03,0.0,,0.0,23.77,N,N,,N,N
9,HV0005,B02510,,2021-01-31 23:56:04,NaT,2021-02-01 00:03:43,2021-02-01 00:39:37,80,241,15.998,...,4.44,0.0,,0.0,35.8,N,N,N,N,N


In [41]:
from pyspark.sql.functions import col, desc

In [42]:
sdf.where(sdf.sales_tax == 4.44).count()

62415

In [43]:
sdf.select("hvfhs_license_num").show()

+-----------------+
|hvfhs_license_num|
+-----------------+
|           HV0003|
|           HV0003|
|           HV0005|
|           HV0005|
|           HV0003|
|           HV0003|
|           HV0003|
|           HV0003|
|           HV0003|
|           HV0005|
|           HV0005|
|           HV0005|
|           HV0005|
|           HV0003|
|           HV0003|
|           HV0003|
|           HV0003|
|           HV0003|
|           HV0004|
|           HV0004|
+-----------------+
only showing top 20 rows



In [12]:
import os
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"

In [22]:
sdf.createOrReplaceTempView("hvfhs")

In [23]:
spark.sql(""" SELECT  hvfhs_license_num, sales_tax FROM hvfhs 
          WHERE hvfhs_license_num = 'HV0004' """) \
     .show(5)

+-----------------+---------+
|hvfhs_license_num|sales_tax|
+-----------------+---------+
|           HV0004|     1.42|
|           HV0004|     1.04|
|           HV0004|     2.55|
|           HV0004|      2.1|
|           HV0004|     3.53|
+-----------------+---------+
only showing top 5 rows



In [25]:
spark.sql(""" SELECT count(*) FROM hvfhs 
          WHERE hvfhs_license_num = 'HV0004' """) \
     .show(5)

+--------+
|count(1)|
+--------+
|  781804|
+--------+

