# PySpark

In [1]:
from pathlib import Path
from typing import Optional

import pyspark
from pyspark.sql import SparkSession

In [2]:
# Build (or reuse) a Spark session running locally with master "local[*]".
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("test")
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


23/03/03 08:14:41 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


## Check Spark is working

In [3]:
# Load the taxi zone lookup CSV; the first row supplies the column names.
df = (
    spark.read
    .option("header", "true")
    .csv("./data/taxi+_zone_lookup.csv")
)

type(df)

pyspark.sql.dataframe.DataFrame

In [4]:
# Action: print the first 5 rows to confirm the CSV header was picked up.
df.show(5)

+----------+-------------+--------------------+------------+
|LocationID|      Borough|                Zone|service_zone|
+----------+-------------+--------------------+------------+
|         1|          EWR|      Newark Airport|         EWR|
|         2|       Queens|         Jamaica Bay|   Boro Zone|
|         3|        Bronx|Allerton/Pelham G...|   Boro Zone|
|         4|    Manhattan|       Alphabet City| Yellow Zone|
|         5|Staten Island|       Arden Heights|   Boro Zone|
+----------+-------------+--------------------+------------+
only showing top 5 rows



In [7]:
# Persist the lookup table as Parquet, replacing output from any earlier run.
df.write.mode("overwrite").parquet("./data/taxi_zones")

In [8]:
ls -l ./data/taxi_zones/

total 16
-rw-r--r--  1 vgarist  staff     0 Mar  3 08:15 _SUCCESS
-rw-r--r--  1 vgarist  staff  5916 Mar  3 08:15 part-00000-8385dc98-c529-47ec-a1c4-a227a1b2520a-c000.snappy.parquet


In [9]:
# Sanity check: the saved Parquet directory can be read straight back.
spark.read.parquet("./data/taxi_zones").show(2)

+----------+-------+--------------+------------+
|LocationID|Borough|          Zone|service_zone|
+----------+-------+--------------+------------+
|         1|    EWR|Newark Airport|         EWR|
|         2| Queens|   Jamaica Bay|   Boro Zone|
+----------+-------+--------------+------------+
only showing top 2 rows



## First look at PySpark

In [10]:
from pyspark.sql import types

- read data

In [11]:
# FHV trips: gzip-compressed CSV, read with a header row and no explicit schema.
df = (
    spark.read
    .option("header", "true")
    .csv("./data/fhvhv_trips/fhvhv_tripdata_2021-06.csv.gz")
)

In [12]:
# Action: count all rows in the trips file.
df.count()

                                                                                

14961892

In [13]:
# Without an explicit schema, every column is read as a string.
df.schema

StructType([StructField('dispatching_base_num', StringType(), True), StructField('pickup_datetime', StringType(), True), StructField('dropoff_datetime', StringType(), True), StructField('PULocationID', StringType(), True), StructField('DOLocationID', StringType(), True), StructField('SR_Flag', StringType(), True), StructField('Affiliated_base_number', StringType(), True)])

In [14]:
# Pull a small sample (at most 1000 rows) from Spark into a pandas DataFrame:
df_pandas = df.limit(1000).toPandas()

- update schema

In [15]:
# Explicit schema for the FHV trips file: timestamps and integer location IDs
# replace the all-string columns obtained when no schema is given.
# Every field is declared nullable.
schema = types.StructType([
    types.StructField(column_name, column_type, True)
    for column_name, column_type in (
        ('dispatching_base_num', types.StringType()),
        ('pickup_datetime', types.TimestampType()),
        ('dropoff_datetime', types.TimestampType()),
        ('PULocationID', types.IntegerType()),
        ('DOLocationID', types.IntegerType()),
        ('SR_Flag', types.StringType()),
        ('Affiliated_base_number', types.StringType()),
    )
])

In [16]:
# Re-read the FHV trips, this time applying the explicit schema defined above.
df = (
    spark.read
    .option("header", "true")
    .schema(schema)
    .csv("./data/fhvhv_trips/fhvhv_tripdata_2021-06.csv.gz")
)

In [17]:
# Check the first two rows: with the explicit schema, timestamps and location IDs
# should come back as typed values rather than strings.
df.head(2)

[Row(dispatching_base_num='B02764', pickup_datetime=datetime.datetime(2021, 6, 1, 0, 2, 41), dropoff_datetime=datetime.datetime(2021, 6, 1, 0, 7, 46), PULocationID=174, DOLocationID=18, SR_Flag='N', Affiliated_base_number='B02764'),
 Row(dispatching_base_num='B02764', pickup_datetime=datetime.datetime(2021, 6, 1, 0, 16, 16), dropoff_datetime=datetime.datetime(2021, 6, 1, 0, 21, 14), PULocationID=32, DOLocationID=254, SR_Flag='N', Affiliated_base_number='B02764')]

In [18]:
# Repartition into 12 partitions ahead of saving.
df = df.repartition(numPartitions=12)  # lazy transformation: no job runs until an action (e.g. write) is called

In [19]:
# Write the repartitioned trips as Parquet, replacing output from any earlier run.
df.write.mode("overwrite").parquet("./data/fhvhv_trips/2021/06/")

[Stage 11:>                                                         (0 + 1) / 1]

23/03/03 08:16:06 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers


                                                                                

In [20]:
ls -lh {"./data/fhvhv_trips/2021/06/"}

total 591872
-rw-r--r--  1 vgarist  staff     0B Mar  3 08:16 _SUCCESS
-rw-r--r--  1 vgarist  staff    24M Mar  3 08:16 part-00000-c18a05d6-ebd2-4e74-ad7a-2613b288c1d7-c000.snappy.parquet
-rw-r--r--  1 vgarist  staff    24M Mar  3 08:16 part-00001-c18a05d6-ebd2-4e74-ad7a-2613b288c1d7-c000.snappy.parquet
-rw-r--r--  1 vgarist  staff    24M Mar  3 08:16 part-00002-c18a05d6-ebd2-4e74-ad7a-2613b288c1d7-c000.snappy.parquet
-rw-r--r--  1 vgarist  staff    24M Mar  3 08:16 part-00003-c18a05d6-ebd2-4e74-ad7a-2613b288c1d7-c000.snappy.parquet
-rw-r--r--  1 vgarist  staff    24M Mar  3 08:16 part-00004-c18a05d6-ebd2-4e74-ad7a-2613b288c1d7-c000.snappy.parquet
-rw-r--r--  1 vgarist  staff    24M Mar  3 08:16 part-00005-c18a05d6-ebd2-4e74-ad7a-2613b288c1d7-c000.snappy.parquet
-rw-r--r--  1 vgarist  staff    24M Mar  3 08:16 part-00006-c18a05d6-ebd2-4e74-ad7a-2613b288c1d7-c000.snappy.parquet
-rw-r--r--  1 vgarist  staff    24M Mar  3 08:16 part-00007-c18a05d6-ebd2-4e74-ad7a-2613b288c1d7-c000

## Spark DataFrames

In [33]:
from pyspark.sql import functions as F

In [23]:
# Read the trips back from the Parquet directory written earlier and preview two rows:
df = spark.read.parquet("./data/fhvhv_trips/2021/06/")
df.show(2)

+--------------------+-------------------+-------------------+------------+------------+-------+----------------------+
|dispatching_base_num|    pickup_datetime|   dropoff_datetime|PULocationID|DOLocationID|SR_Flag|Affiliated_base_number|
+--------------------+-------------------+-------------------+------------+------------+-------+----------------------+
|              B02875|2021-06-16 22:08:45|2021-06-16 22:38:10|          48|         181|      N|                B02875|
|              B02875|2021-06-27 08:13:22|2021-06-27 08:16:18|          10|          10|      N|                B02875|
+--------------------+-------------------+-------------------+------------+------------+-------+----------------------+
only showing top 2 rows



In [32]:
# filter data (first TRANSFORMATION, then ACTION)

# Transformations (lazy): select, filter, join, groupby
# Actions (trigger execution): show, take, head, write, ...
#
# Filter BEFORE projecting: Affiliated_base_number is not one of the selected
# columns, so filtering after the select only works because Spark resolves the
# reference against the parent DataFrame. Filtering first keeps the chain
# valid on its own.
(
    df
    .filter(df["Affiliated_base_number"] == "B02875")
    .select("pickup_datetime", "dropoff_datetime", "PULocationID", "DOLocationID")
    .show(5)
)

+-------------------+-------------------+------------+------------+
|    pickup_datetime|   dropoff_datetime|PULocationID|DOLocationID|
+-------------------+-------------------+------------+------------+
|2021-06-16 22:08:45|2021-06-16 22:38:10|          48|         181|
|2021-06-27 08:13:22|2021-06-27 08:16:18|          10|          10|
|2021-06-14 07:12:35|2021-06-14 07:46:52|         112|         166|
|2021-06-06 04:05:40|2021-06-06 04:18:32|          79|         261|
|2021-06-17 09:17:20|2021-06-17 09:48:09|          68|         138|
+-------------------+-------------------+------------+------------+
only showing top 5 rows



In [36]:
# option #1: built-in column functions (F) — derive date-only columns from the
# pickup/dropoff timestamps and show them next to the location IDs.
(
    df
    .withColumn("pickup_date", F.to_date(df["pickup_datetime"]))
    .withColumn("dropoff_date", F.to_date(df["dropoff_datetime"]))
    .select("pickup_date", "dropoff_date", "PULocationID", "DOLocationID")
    .show(5)
)

+-----------+------------+------------+------------+
|pickup_date|dropoff_date|PULocationID|DOLocationID|
+-----------+------------+------------+------------+
| 2021-06-16|  2021-06-16|          48|         181|
| 2021-06-27|  2021-06-27|          10|          10|
| 2021-06-13|  2021-06-13|          89|         189|
| 2021-06-15|  2021-06-15|          36|          82|
| 2021-06-09|  2021-06-09|         254|         254|
+-----------+------------+------------+------------+
only showing top 5 rows



In [39]:
# option #2: use UDFs
def process_location(location_id: Optional[int]) -> Optional[str]:
    """Map a location ID to a placeholder city label.

    Args:
        location_id: Location ID taken from a nullable integer column;
            may be ``None`` when the source value is missing.

    Returns:
        ``"Moscow"`` for IDs in (0, 100], ``"London"`` for IDs in (100, 200],
        ``"Tokyo"`` for any other integer, and ``None`` when the ID is missing.
    """
    # The location columns are declared nullable, and a missing value reaches
    # the function as None. Comparing None with an int raises TypeError, so
    # pass missing IDs through as None instead of failing.
    if location_id is None:
        return None
    if 0 < location_id <= 100:
        return "Moscow"
    if 100 < location_id <= 200:
        return "London"
    return "Tokyo"


# Wrap the Python function as a Spark UDF that returns strings, then try it on
# both location columns alongside the derived date columns.
process_location_udf = F.udf(process_location, types.StringType())
(
    df
    .withColumn("pickup_date", F.to_date(df["pickup_datetime"]))
    .withColumn("dropoff_date", F.to_date(df["dropoff_datetime"]))
    .withColumn("pickup_location", process_location_udf(df["PULocationID"]))
    .withColumn("dropoff_location", process_location_udf(df["DOLocationID"]))
    .select("pickup_date", "dropoff_date", "pickup_location", "dropoff_location")
    .show(5)
)

In [42]:
# option #3: SQL queries
# Register the DataFrame as a temporary view so SQL statements can refer to it by name.
df.createOrReplaceTempView("fhv_trips")

# Same kind of projection as the DataFrame-API examples, written in SQL.
preview_query = """
SELECT
    date(pickup_datetime) as pickup_date,
    date(dropoff_datetime) as dropoff_date,
    PULocationID as pickup_location_id,
    DOLocationID as dropoff_location_id,
    lower(Affiliated_base_number) as base_num
FROM fhv_trips
"""
spark.sql(preview_query).show(5)

+-----------+------------+------------------+-------------------+--------+
|pickup_date|dropoff_date|pickup_location_id|dropoff_location_id|base_num|
+-----------+------------+------------------+-------------------+--------+
| 2021-06-16|  2021-06-16|                48|                181|  b02875|
| 2021-06-27|  2021-06-27|                10|                 10|  b02875|
| 2021-06-13|  2021-06-13|                89|                189|    null|
| 2021-06-15|  2021-06-15|                36|                 82|  b02764|
| 2021-06-09|  2021-06-09|               254|                254|    null|
+-----------+------------+------------------+-------------------+--------+
only showing top 5 rows



## SQL with Spark

In [55]:
# Trip counts per week for rows with SR_Flag = 'N', queried from the
# `fhv_trips` temporary view.
# NOTE(review): the alias `d_month` is misleading — date_trunc('WEEK', ...)
# truncates to the week, so the column holds week-start dates, not months.
# Renaming it would change the saved schema, so it is only flagged here.
df_stats = spark.sql("""
SELECT
    date(date_trunc('WEEK', pickup_datetime)) as d_month,
    count(*) as trip_cnt
FROM fhv_trips
WHERE SR_Flag = 'N'
GROUP BY 1
""")

In [56]:
# Save the query result as Parquet (mode="overwrite" replaces output from an earlier run):
df_stats.write.parquet("./data/fhvhv_trips/weekly_dynamics", mode="overwrite")

                                                                                

In [59]:
# Check that the saved result can be read back and looks as expected:
spark.read.parquet("./data/fhvhv_trips/weekly_dynamics/").show()

+----------+--------+
|   d_month|trip_cnt|
+----------+--------+
|2021-06-28| 1377049|
|2021-05-31| 3061878|
|2021-06-14| 3488097|
|2021-06-07| 3524541|
|2021-06-21| 3506584|
+----------+--------+

