In [2]:
from pyspark.sql import SparkSession
import pyspark.sql.types as st
import pyspark.sql.functions as sf

In [3]:
spark = (
    SparkSession
    .builder
    .appName('PySpark tutorial')
    .master('local[*]')
    .getOrCreate()
)

In [4]:
data = [
    (1, 'Alice', 20),
    (2, 'Bob', 25),
    (3, 'Charlie', 30),
    (4, 'David', 40)
]

columns = ['id', 'name', 'age']

df = spark.createDataFrame(data, columns)

df.show()

+---+-------+---+
| id|   name|age|
+---+-------+---+
|  1|  Alice| 20|
|  2|    Bob| 25|
|  3|Charlie| 30|
|  4|  David| 40|
+---+-------+---+



In [5]:
df.printSchema()

root
 |-- id: long (nullable = true)
 |-- name: string (nullable = true)
 |-- age: long (nullable = true)



In [6]:
schema = st.StructType(
    [
        st.StructField("id", st.IntegerType(), True),
        st.StructField("name", st.StringType(), True),
        st.StructField("age", st.IntegerType(), True)
    ]
)

df = spark.createDataFrame(data, schema=schema)

df.printSchema()

root
 |-- id: integer (nullable = true)
 |-- name: string (nullable = true)
 |-- age: integer (nullable = true)



#### Read / Write

In [7]:
df_read_csv = spark.read.csv('../cubix_de_pyspark/data/d18.csv', header=True, sep=';')

df_read_csv.show()

+---------+------+-----+------+
|  aquifer|  d18O|  T_C|d18O_2|
+---------+------+-----+------+
|   Dacian|-12.14|11.17|-12.14|
|   Dacian|-10.11|10.57|-10.11|
|   Dacian|-12.44| 9.97|-13.44|
|   Dacian|-12.20| 9.95|-12.70|
|  Meotian|-10.10|13.37|-10.10|
|  Pontian|-14.24|10.37|-15.24|
|  Pontian|-11.08| 8.13|-11.08|
|  Pontian|-12.07| 8.51|-13.07|
|  Pontian|-11.32| 8.25|-11.32|
|  Pontian|-10.71|13.42|-10.71|
|  Pontian|-12.14| 8.38|-12.64|
|  Pontian| -9.82|10.98| -9.82|
|Sarmatian| -9.87|11.27| -9.87|
|Sarmatian|-11.77| 8.23|-12.77|
+---------+------+-----+------+



same as above, but in different format to parametrize

In [8]:
df_read_csv = (
    spark
    .read
    .format('csv')
    .option('header', True)
    .option('sep', ';')
    .load('../cubix_de_pyspark/data/d18.csv')
)

df_read_csv.show()

+---------+------+-----+------+
|  aquifer|  d18O|  T_C|d18O_2|
+---------+------+-----+------+
|   Dacian|-12.14|11.17|-12.14|
|   Dacian|-10.11|10.57|-10.11|
|   Dacian|-12.44| 9.97|-13.44|
|   Dacian|-12.20| 9.95|-12.70|
|  Meotian|-10.10|13.37|-10.10|
|  Pontian|-14.24|10.37|-15.24|
|  Pontian|-11.08| 8.13|-11.08|
|  Pontian|-12.07| 8.51|-13.07|
|  Pontian|-11.32| 8.25|-11.32|
|  Pontian|-10.71|13.42|-10.71|
|  Pontian|-12.14| 8.38|-12.64|
|  Pontian| -9.82|10.98| -9.82|
|Sarmatian| -9.87|11.27| -9.87|
|Sarmatian|-11.77| 8.23|-12.77|
+---------+------+-----+------+



In [9]:
df_read_parquet = spark.read.parquet('../cubix_de_pyspark/data/yellow_tripdata_2024-09.parquet')

df_read_parquet.show(10)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       1| 2024-09-01 00:05:51|  2024-09-01 00:45:03|              1|          9.8|         1|                 N|         138|          48|           1|       47.8|10.25|    0.5|      13.

In [10]:
df_read_parquet = (
    spark
    .read
    .format('parquet')
    .load('../cubix_de_pyspark/data/yellow_tripdata_2024-09.parquet')
)

df_read_parquet.show(4)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       1| 2024-09-01 00:05:51|  2024-09-01 00:45:03|              1|          9.8|         1|                 N|         138|          48|           1|       47.8|10.25|    0.5|      13.

In [11]:
df_read_parquet.printSchema()

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp_ntz (nullable = true)
 |-- tpep_dropoff_datetime: timestamp_ntz (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: long (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- Airport_fee: double (nullable = true)



In [12]:
df_read_parquet.count()

3633030

In [1]:
## doesn't work locally
# df_read_parquet.write.parquet('../cubix_de_pyspark/data/yellow_tripdata.parquet')

#### Transformation
##### Selecting

In [14]:
taxi_df = (
    spark
    .read
    .format('parquet')
    .load('../cubix_de_pyspark/data/yellow_tripdata_2024-09.parquet')
)

In [None]:
taxi_df.describe().show()

+-------+------------------+------------------+-----------------+------------------+------------------+-----------------+-----------------+------------------+------------------+------------------+-------------------+-----------------+------------------+---------------------+------------------+--------------------+-------------------+
|summary|          VendorID|   passenger_count|    trip_distance|        RatecodeID|store_and_fwd_flag|     PULocationID|     DOLocationID|      payment_type|       fare_amount|             extra|            mta_tax|       tip_amount|      tolls_amount|improvement_surcharge|      total_amount|congestion_surcharge|        Airport_fee|
+-------+------------------+------------------+-----------------+------------------+------------------+-----------------+-----------------+------------------+------------------+------------------+-------------------+-----------------+------------------+---------------------+------------------+--------------------+-------------

In [None]:
taxi_df.select('VendorID', 'passenger_count').show(5)

+--------+---------------+
|VendorID|passenger_count|
+--------+---------------+
|       1|              1|
|       1|              1|
|       2|              2|
|       2|              1|
|       2|              2|
+--------+---------------+
only showing top 5 rows



In [None]:
taxi_df.select(taxi_df.VendorID, taxi_df.passenger_count).show(5)

+--------+---------------+
|VendorID|passenger_count|
+--------+---------------+
|       1|              1|
|       1|              1|
|       2|              2|
|       2|              1|
|       2|              2|
+--------+---------------+
only showing top 5 rows



In [None]:
taxi_df.select(sf.col('VendorID'), sf.col('passenger_count')).show(5)

+--------+---------------+
|VendorID|passenger_count|
+--------+---------------+
|       1|              1|
|       1|              1|
|       2|              2|
|       2|              1|
|       2|              2|
+--------+---------------+
only showing top 5 rows



github.com/palantir/pyspark-style-guide

- third method is better, except for joins, when join columns have the same name, then writing df.col is better

##### Filtering

In [None]:
taxi_df.where(sf.col('trip_distance') > 5).show(5)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       1| 2024-09-01 00:05:51|  2024-09-01 00:45:03|              1|          9.8|         1|                 N|         138|          48|           1|       47.8|10.25|    0.5|      13.

In [None]:
taxi_df.where((sf.col('trip_distance') > 5) & (sf.col('passenger_count') > 1)).show(5)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       1| 2024-09-01 00:08:28|  2024-09-01 00:39:06|              4|          9.8|         1|                 N|          93|         161|           1|       44.3|  3.5|    0.5|      9.8

when code is wrapped inside () -> we can use linebreaks without \

In [None]:
(
    taxi_df.where(
    (sf.col('trip_distance') > 5) & (sf.col('passenger_count') > 1)
    )
.show(5)
)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       1| 2024-09-01 00:08:28|  2024-09-01 00:39:06|              4|          9.8|         1|                 N|          93|         161|           1|       44.3|  3.5|    0.5|      9.8

In [None]:
taxi_df.where(sf.col('store_and_fwd_flag') == 'N').show(5)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       1| 2024-09-01 00:05:51|  2024-09-01 00:45:03|              1|          9.8|         1|                 N|         138|          48|           1|       47.8|10.25|    0.5|      13.

negate with != or with ~()

In [None]:
(
    taxi_df
    .where(
        ~(sf.col('store_and_fwd_flag') == 'N'))
        .show(5)
)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       1| 2024-09-01 00:23:40|  2024-09-01 00:43:46|              1|          3.9|         1|                 Y|         142|          42|           1|       20.5|  3.5|    0.5|      6.3

##### withColumn

In [None]:
taxi_df = (
    taxi_df
    .withColumn(
        'total_amount_with_all_tax',
        sf.col('total_amount') + sf.col('congestion_surcharge') + sf.col('Airport_fee')
    )
)

(
    taxi_df
    .select(
        sf.col('total_amount'),
        sf.col('Airport_fee'),
        sf.col('congestion_surcharge'),
        sf.col('total_amount_with_all_tax')
    )
).show(5)

+------------+-----------+--------------------+-------------------------+
|total_amount|Airport_fee|congestion_surcharge|total_amount_with_all_tax|
+------------+-----------+--------------------+-------------------------+
|       79.79|       1.75|                 2.5|                    84.04|
|        13.1|        0.0|                 2.5|                     15.6|
|        16.0|        0.0|                 0.0|                     16.0|
|       31.75|        0.0|                 0.0|                    31.75|
|        26.4|        0.0|                 2.5|                     28.9|
+------------+-----------+--------------------+-------------------------+
only showing top 5 rows



##### withColumnRenamed

In [None]:
taxi_df.withColumnRenamed('total_amount_with_all_tax', 'total_amount_with_taxes')
taxi_df.show(5)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+-------------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|total_amount_with_all_tax|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+-------------------------+
|       1| 2024-09-01 00:05:51|  2024-09-01 00:45:03|              1|          9.8|         1|               

spark df is immutable, so we have to reassign df to itself, there is no inplace=True option

In [None]:
taxi_df = taxi_df.withColumnRenamed('total_amount_with_all_tax', 'total_amount_with_taxes')
taxi_df.show(5)

##### Drop

In [None]:
taxi_df = taxi_df.drop(sf.col('total_amount_with_taxes'))
taxi_df.show(2)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+-------------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|total_amount_with_all_tax|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+-------------------------+
|       1| 2024-09-01 00:05:51|  2024-09-01 00:45:03|              1|          9.8|         1|               

##### Group by (order by)

In [None]:
(
    taxi_df
    .groupBy(sf.col('PULocationID'))
    .count()
    .sort(sf.col('count'), ascending=False)
    .show(10)
)

+------------+------+
|PULocationID| count|
+------------+------+
|         132|188147|
|         237|166033|
|         161|153988|
|         236|149284|
|         186|120393|
|         162|117163|
|         230|113531|
|         138|111112|
|         142|107537|
|          68|104585|
+------------+------+
only showing top 10 rows



In [None]:
(
    taxi_df
    .groupBy(sf.col('payment_type'))
    .agg(
        sf.sum('total_amount').alias('total_amount'),
        sf.avg('total_amount').alias('average_amount')
    )
    .sort(sf.col('total_amount'), ascending=False)
    .withColumn('total_amount', sf.format_number(sf.col('total_amount'), 2))
    .withColumn('average_amount', sf.format_number(sf.col('average_amount'), 2))
    .show(5)
)

+------------+-------------+--------------+
|payment_type| total_amount|average_amount|
+------------+-------------+--------------+
|           1|81,090,276.01|         31.13|
|           0|11,577,187.59|         23.93|
|           2|10,682,143.65|         23.99|
|           3|   198,527.68|          7.84|
|           4|   135,263.93|          1.83|
+------------+-------------+--------------+



: 

In [15]:
taxi_df.orderBy(sf.col('payment_type')).show(5)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       1| 2024-09-01 00:47:27|  2024-09-01 00:54:09|           NULL|          1.1|      NULL|              NULL|         170|         234|           0|      20.29|  0.0|    0.5|       0.

sf.format_number turns integer to string!

#### Joins

In [19]:
# DF1: employees

employees_data = [
    (1, 'Alice', 101),
    (2, 'Bob', 102),
    (3, 'Charlie', 103),
    (4, 'David', 101),
    (5, 'Eve', 104),
    (6, 'Frank', 105),
    (7, 'Grace', 102),
    (8, 'Helen', 106),
    (9, 'Ian', 103),
    (10, 'Jack', 104)
]

employees_columns = ['employee_id', 'name', 'department_id']
employees_df = spark.createDataFrame(employees_data, employees_columns)

# DF2: departments

departments_data = [
    (101, 'HR'),
    (102, 'Finance'),
    (103, 'IT'),
    (104, 'Marketing'),
    (105, 'Sales'),
    (107, 'Operations')
]

departments_columns = ['department_id', 'department_name']
departments_df = spark.createDataFrame(departments_data, departments_columns)

# some programs doesn't print every .show() just the first, so sometimes print() is needed to be wrapped around show()
employees_df.show()
departments_df.show()

+-----------+-------+-------------+
|employee_id|   name|department_id|
+-----------+-------+-------------+
|          1|  Alice|          101|
|          2|    Bob|          102|
|          3|Charlie|          103|
|          4|  David|          101|
|          5|    Eve|          104|
|          6|  Frank|          105|
|          7|  Grace|          102|
|          8|  Helen|          106|
|          9|    Ian|          103|
|         10|   Jack|          104|
+-----------+-------+-------------+

+-------------+---------------+
|department_id|department_name|
+-------------+---------------+
|          101|             HR|
|          102|        Finance|
|          103|             IT|
|          104|      Marketing|
|          105|          Sales|
|          107|     Operations|
+-------------+---------------+



In [20]:
inner_join_df = (
    employees_df
    .join(departments_df, 'department_id')
    .sort(sf.col('department_id'))
)

inner_join_df.show()

+-------------+-----------+-------+---------------+
|department_id|employee_id|   name|department_name|
+-------------+-----------+-------+---------------+
|          101|          1|  Alice|             HR|
|          101|          4|  David|             HR|
|          102|          2|    Bob|        Finance|
|          102|          7|  Grace|        Finance|
|          103|          3|Charlie|             IT|
|          103|          9|    Ian|             IT|
|          104|          5|    Eve|      Marketing|
|          104|         10|   Jack|      Marketing|
|          105|          6|  Frank|          Sales|
+-------------+-----------+-------+---------------+



In [21]:
left_join_df = (
    employees_df
    .join(departments_df, 'department_id', how='left')
    .sort(sf.col('department_id'))
)

left_join_df.show()

+-------------+-----------+-------+---------------+
|department_id|employee_id|   name|department_name|
+-------------+-----------+-------+---------------+
|          101|          1|  Alice|             HR|
|          101|          4|  David|             HR|
|          102|          2|    Bob|        Finance|
|          102|          7|  Grace|        Finance|
|          103|          3|Charlie|             IT|
|          103|          9|    Ian|             IT|
|          104|          5|    Eve|      Marketing|
|          104|         10|   Jack|      Marketing|
|          105|          6|  Frank|          Sales|
|          106|          8|  Helen|           NULL|
+-------------+-----------+-------+---------------+



In [22]:
full_outer_join_df = (
    employees_df
    .join(departments_df, 'department_id', how='outer')
    .sort(sf.col('department_id'))
)

full_outer_join_df.show()

+-------------+-----------+-------+---------------+
|department_id|employee_id|   name|department_name|
+-------------+-----------+-------+---------------+
|          101|          1|  Alice|             HR|
|          101|          4|  David|             HR|
|          102|          2|    Bob|        Finance|
|          102|          7|  Grace|        Finance|
|          103|          3|Charlie|             IT|
|          103|          9|    Ian|             IT|
|          104|          5|    Eve|      Marketing|
|          104|         10|   Jack|      Marketing|
|          105|          6|  Frank|          Sales|
|          106|          8|  Helen|           NULL|
|          107|       NULL|   NULL|     Operations|
+-------------+-----------+-------+---------------+



#### SparkSQL

In [3]:
taxi_df = (
    spark
    .read
    .format('parquet')
    .load('../cubix_de_pyspark/data/yellow_tripdata_2024-09.parquet')
)

In [4]:
taxi_df.createOrReplaceTempView('taxi_temp_view')

In [5]:
spark.sql('select * from taxi_temp_view limit 10').show()

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       1| 2024-09-01 00:05:51|  2024-09-01 00:45:03|              1|          9.8|         1|                 N|         138|          48|           1|       47.8|10.25|    0.5|      13.

In [6]:
spark.sql("""
    select
          *
    from
          taxi_temp_view
    where
          trip_distance > 10
    limit 10
""").show()

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       1| 2024-09-01 00:57:30|  2024-09-01 01:40:20|              2|         14.5|         1|                 N|         132|          91|           1|       64.6| 2.75|    0.5|     13.7

In [7]:
avg_amount_by_payment_type = spark.sql(
    """
    select
          payment_type,
          avg(total_amount)
    from
          taxi_temp_view
    group by
        payment_type
"""
)

avg_amount_by_payment_type.show()

+------------+------------------+
|payment_type| avg(total_amount)|
+------------+------------------+
|           1|31.132636716912728|
|           3| 7.835484864032835|
|           2|23.992416616129827|
|           4|1.8263854120252212|
|           0|23.933110737169443|
+------------+------------------+



createOrReplaceTempView() or any other view creation only supports SELECTing, no UPDATE/INSERT/DELETE with parquet files, only with delta

#### Caching

In [4]:
import time

In [7]:
data = [(i, i *2) for i in range(100000)]
df = spark.createDataFrame(data, ['number', 'double'])
df.show(5)

start_time = time.time()

# repeat an expensive computation twice
df_filtered = df.filter(df['number'] % 2 == 0)
print(df_filtered.count()) # first computation
print(df_filtered.count()) # repeating computation

end_time = time.time()
print(f'Time taken without caching: {end_time-start_time:.2f} seconds')

+------+------+
|number|double|
+------+------+
|     0|     0|
|     1|     2|
|     2|     4|
|     3|     6|
|     4|     8|
+------+------+
only showing top 5 rows

50000
50000
Time taken without caching: 14.29 seconds


In [8]:
data = [(i, i *2) for i in range(100000)]
df = spark.createDataFrame(data, ['number', 'double'])
df.show(5)

start_time = time.time()

# repeat an expensive computation twice
df_filtered = df.filter(df['number'] % 2 == 0).cache()
print(df_filtered.count()) # first computation
print(df_filtered.count()) # repeating computation

end_time = time.time()
print(f'Time taken with caching: {end_time-start_time:.2f} seconds')

+------+------+
|number|double|
+------+------+
|     0|     0|
|     1|     2|
|     2|     4|
|     3|     6|
|     4|     8|
+------+------+
only showing top 5 rows

50000
50000
Time taken with caching: 7.55 seconds


#### Broadcast

In [9]:
product_ids = [i for i in range(1, 3000001)]
product_names = [f'Product_{i}' for i in range(1, 3000001)]

large_data = list(zip(product_ids, product_names))
columns = ['product_id', 'product_name']

large_df = spark.createDataFrame(large_data, columns)
large_df.show(5)

category_data = [(1, 'Electronics'), (2, 'Clothing'), (3, 'Toys'), (4, 'Groceries'), (5, 'Books')]
category_columns = ['product_id', 'category']

small_df = spark.createDataFrame(category_data, category_columns)
small_df.show(5)

+----------+------------+
|product_id|product_name|
+----------+------------+
|         1|   Product_1|
|         2|   Product_2|
|         3|   Product_3|
|         4|   Product_4|
|         5|   Product_5|
+----------+------------+
only showing top 5 rows

+----------+-----------+
|product_id|   category|
+----------+-----------+
|         1|Electronics|
|         2|   Clothing|
|         3|       Toys|
|         4|  Groceries|
|         5|      Books|
+----------+-----------+



In [10]:
start_time = time.time()

joined_df_no_broadcast = large_df.join(small_df, 'product_id')
joined_df_no_broadcast.show(5)

end_time = time.time()
print(f'Time taken: {end_time - start_time:.2f} seconds')

+----------+------------+-----------+
|product_id|product_name|   category|
+----------+------------+-----------+
|         1|   Product_1|Electronics|
|         5|   Product_5|      Books|
|         3|   Product_3|       Toys|
|         2|   Product_2|   Clothing|
|         4|   Product_4|  Groceries|
+----------+------------+-----------+

Time taken: 19.77 seconds


In [11]:
start_time = time.time()

# perform join with broadcasting the small df
joined_df_no_broadcast = large_df.join(sf.broadcast(small_df), 'product_id')
joined_df_no_broadcast.show(5)

end_time = time.time()
print(f'Time taken: {end_time - start_time:.2f} seconds')

+----------+------------+-----------+
|product_id|product_name|   category|
+----------+------------+-----------+
|         1|   Product_1|Electronics|
|         2|   Product_2|   Clothing|
|         3|   Product_3|       Toys|
|         4|   Product_4|  Groceries|
|         5|   Product_5|      Books|
+----------+------------+-----------+

Time taken: 16.71 seconds


In [12]:
print('Without broadcasting:')
print(large_df.join(small_df, 'product_id').explain())

print('With broadcasting:')
print(large_df.join(sf.broadcast(small_df), 'product_id').explain())

Without broadcasting:
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [product_id#218L, product_name#219, category#232]
   +- SortMergeJoin [product_id#218L], [product_id#231L], Inner
      :- Sort [product_id#218L ASC NULLS FIRST], false, 0
      :  +- Exchange hashpartitioning(product_id#218L, 200), ENSURE_REQUIREMENTS, [plan_id=603]
      :     +- Filter isnotnull(product_id#218L)
      :        +- Scan ExistingRDD[product_id#218L,product_name#219]
      +- Sort [product_id#231L ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(product_id#231L, 200), ENSURE_REQUIREMENTS, [plan_id=604]
            +- Filter isnotnull(product_id#231L)
               +- Scan ExistingRDD[product_id#231L,category#232]


None
With broadcasting:
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [product_id#218L, product_name#219, category#232]
   +- BroadcastHashJoin [product_id#218L], [product_id#231L], Inner, BuildRight, false
      :- Filter isnotnull(p