In [1]:
import pyspark
from pyspark.sql import SparkSession

In [2]:
spark = SparkSession.builder.master("local[*]").appName("test").getOrCreate()

24/03/06 06:05:07 WARN Utils: Your hostname, 005482.local resolves to a loopback address: 127.0.0.1; using 192.168.1.38 instead (on interface en0)
24/03/06 06:05:07 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/03/06 06:05:07 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/03/06 06:05:07 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [3]:
service_type = "yellow"

In [4]:
df = spark.read.parquet(f"data/raw/{service_type}/*/*")
df.show()

                                                                                

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|airport_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+
|       1| 2020-01-01 00:28:15|  2020-01-01 00:33:03|            1.0|          1.2|       1.0|                 N|         238|         239|           1|        6.0|  3.0|    0.5|      1.4

In [5]:
df.columns
if "lpep_pickup_datetime" in df.columns:
    df = df.withColumnRenamed(
        "lpep_pickup_datetime", "pickup_datetime"
    ).withColumnRenamed("lpep_dropoff_datetime", "dropoff_datetime")
elif "tpep_pickup_datetime" in df.columns:
    df = df.withColumnRenamed(
        "tpep_pickup_datetime", "pickup_datetime"
    ).withColumnRenamed("tpep_dropoff_datetime", "dropoff_datetime")
df.columns

['VendorID',
 'pickup_datetime',
 'dropoff_datetime',
 'passenger_count',
 'trip_distance',
 'RatecodeID',
 'store_and_fwd_flag',
 'PULocationID',
 'DOLocationID',
 'payment_type',
 'fare_amount',
 'extra',
 'mta_tax',
 'tip_amount',
 'tolls_amount',
 'improvement_surcharge',
 'total_amount',
 'congestion_surcharge',
 'airport_fee']

In [6]:
df.createOrReplaceTempView("table")

In [7]:
revenue_hourly = spark.sql(
    """
SELECT 
    date_trunc('hour', pickup_datetime) AS hour,
    PULocationID as zone,
    
    SUM(total_amount) AS amount,
    COUNT(1) as record_count
FROM 
    table
GROUP BY 
    1,2
ORDER BY amount DESC
"""
)

In [8]:
revenue_hourly.show()



+-------------------+----+------------------+------------+
|               hour|zone|            amount|record_count|
+-------------------+----+------------------+------------+
|2020-03-10 09:00:00| 193|1000026.4500000001|           3|
|2020-10-07 10:00:00|  41|         998602.27|          21|
|2021-09-05 18:00:00| 141|  819787.340000001|         107|
|2020-03-04 17:00:00| 166| 673021.1200000009|         117|
|2020-05-04 20:00:00| 142|429733.24999999994|          12|
|2022-01-07 11:00:00| 107|402162.56999999995|          78|
|2021-03-18 12:00:00| 161|401589.70999999973|         228|
|2020-12-26 13:00:00| 170|399255.83999999997|          57|
|2021-04-10 13:00:00| 234|398795.42999999964|         184|
|2022-06-11 09:00:00| 163|397751.55999999953|         100|
|2022-09-24 17:00:00| 233|         189123.76|          92|
|2020-08-14 17:00:00| 142|188859.90999999983|          79|
|2020-11-17 06:00:00|  41|151590.91999999995|           8|
|2022-12-29 23:00:00| 132| 41493.56000000006|         56

                                                                                

In [9]:
revenue_hourly.coalesce(1).write.parquet(
    f"data/report/revenue/{service_type}", mode="overwrite"
)

                                                                                

In [10]:
green_hourly = spark.read.parquet("data/report/revenue/green/*")
yellow_hourly = spark.read.parquet("data/report/revenue/yellow/*")

In [11]:
from pyspark.sql import functions as F

df_green_type = green_hourly.withColumn("service_type", F.lit("green"))
df_yellow_type = yellow_hourly.withColumn("service_type", F.lit("yellow"))

In [12]:
green_hourly.createOrReplaceTempView("green_hourly")
yellow_hourly.createOrReplaceTempView("yellow_hourly")

In [13]:
revenue_hourly = spark.sql(
    """
SELECT 
    g.hour,
    g.zone,
    
    ROUND(g.amount+y.amount, 2) AS total_amount,
    g.record_count+y.record_count AS total_count,

    ROUND(g.amount, 2) green_amount,
    ROUND(y.amount, 2) yellow_amount,
    g.record_count green_count,
    y.record_count yellow_count
FROM 
    green_hourly g
JOIN
    yellow_hourly y
ON 
    g.hour = y.hour 
    AND g.zone = y.zone
ORDER BY total_amount DESC
"""
)

In [14]:
revenue_hourly.show()



+-------------------+----+------------+-----------+------------+-------------+-----------+------------+
|               hour|zone|total_amount|total_count|green_amount|yellow_amount|green_count|yellow_count|
+-------------------+----+------------+-----------+------------+-------------+-----------+------------+
|2020-03-10 09:00:00| 193|  1000097.68|          8|       71.23|   1000026.45|          5|           3|
|2020-10-07 10:00:00|  41|   998718.75|         31|      116.48|    998602.27|         10|          21|
|2020-03-04 17:00:00| 166|    673987.3|        177|      966.18|    673021.12|         60|         117|
|2020-11-17 06:00:00|  41|   151682.54|         13|       91.62|    151590.92|          5|           8|
|2020-01-26 20:00:00| 132|    38904.91|        709|         9.3|     38895.61|          1|         708|
|2020-01-27 16:00:00| 132|    35945.89|        589|         9.8|     35936.09|          1|         588|
|2020-03-01 20:00:00| 132|    34817.04|        623|       59.92|

                                                                                

In [15]:
green_hourly_tmp = green_hourly.withColumnRenamed(
    "amount", "green_amount"
).withColumnRenamed("record_count", "green_record_count")

yellow_hourly_tmp = yellow_hourly.withColumnRenamed(
    "amount", "yellow_amount"
).withColumnRenamed("record_count", "yellow_record_count")

In [16]:
df_join = green_hourly_tmp.join(yellow_hourly_tmp, on=["hour", "zone"], how="outer")

In [17]:
df_join.write.parquet("data/report/revenue/total", mode="overwrite")

24/03/06 06:05:28 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
24/03/06 06:05:28 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 84.44% for 9 writers
24/03/06 06:05:28 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 76.00% for 10 writers
24/03/06 06:05:28 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 69.09% for 11 writers
24/03/06 06:05:28 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 63.33% for 12 writers
24/03/06 06:05:28 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 58.46% for 13 writers
24/03/06 06:05:28 WARN MemoryManager: Total allocation exceeds 95.

In [18]:
df_join.show()

+-------------------+----+------------+------------------+-------------+-------------------+
|               hour|zone|green_amount|green_record_count|yellow_amount|yellow_record_count|
+-------------------+----+------------+------------------+-------------+-------------------+
|2001-01-01 01:00:00| 230|        NULL|              NULL|         24.8|                  1|
|2002-10-21 01:00:00| 230|        NULL|              NULL|         20.8|                  1|
|2002-10-21 18:00:00| 138|        NULL|              NULL|        260.3|                  1|
|2002-10-22 08:00:00| 226|        NULL|              NULL|         24.8|                  1|
|2002-10-22 19:00:00| 158|        NULL|              NULL|         13.0|                  1|
|2002-10-22 23:00:00| 186|        NULL|              NULL|         13.8|                  1|
|2002-10-23 00:00:00| 264|        NULL|              NULL|          6.8|                  1|
|2002-10-23 04:00:00|  79|        NULL|              NULL|         25.

In [19]:
df_zones = spark.read.csv(
    "data/zones/taxi_zone_lookup.csv", header=True, inferSchema=True
)

In [20]:
df_zones.write.parquet("data/zones/taxi_zone_lookup.parquet", mode="overwrite")

In [21]:
df_zones = spark.read.parquet("data/zones/taxi_zone_lookup.parquet").withColumnRenamed(
    "LocationID", "zone"
)

In [22]:
df_zones.show()

+----+-------------+--------------------+------------+
|zone|      Borough|                Zone|service_zone|
+----+-------------+--------------------+------------+
|   1|          EWR|      Newark Airport|         EWR|
|   2|       Queens|         Jamaica Bay|   Boro Zone|
|   3|        Bronx|Allerton/Pelham G...|   Boro Zone|
|   4|    Manhattan|       Alphabet City| Yellow Zone|
|   5|Staten Island|       Arden Heights|   Boro Zone|
|   6|Staten Island|Arrochar/Fort Wad...|   Boro Zone|
|   7|       Queens|             Astoria|   Boro Zone|
|   8|       Queens|        Astoria Park|   Boro Zone|
|   9|       Queens|          Auburndale|   Boro Zone|
|  10|       Queens|        Baisley Park|   Boro Zone|
|  11|     Brooklyn|          Bath Beach|   Boro Zone|
|  12|    Manhattan|        Battery Park| Yellow Zone|
|  13|    Manhattan|   Battery Park City| Yellow Zone|
|  14|     Brooklyn|           Bay Ridge|   Boro Zone|
|  15|       Queens|Bay Terrace/Fort ...|   Boro Zone|
|  16|    

In [23]:
df_join_zone = df_join.join(df_zones, on="zone", how="inner")

In [24]:
df_join_zone.show()

+----+-------------------+------------+------------------+-------------+-------------------+---------+--------------------+------------+
|zone|               hour|green_amount|green_record_count|yellow_amount|yellow_record_count|  Borough|                Zone|service_zone|
+----+-------------------+------------+------------------+-------------+-------------------+---------+--------------------+------------+
| 230|2001-01-01 01:00:00|        NULL|              NULL|         24.8|                  1|Manhattan|Times Sq/Theatre ...| Yellow Zone|
| 230|2002-10-21 01:00:00|        NULL|              NULL|         20.8|                  1|Manhattan|Times Sq/Theatre ...| Yellow Zone|
| 138|2002-10-21 18:00:00|        NULL|              NULL|        260.3|                  1|   Queens|   LaGuardia Airport|    Airports|
| 226|2002-10-22 08:00:00|        NULL|              NULL|         24.8|                  1|   Queens|           Sunnyside|   Boro Zone|
| 158|2002-10-22 19:00:00|        NULL|  

In [25]:
from datetime import datetime

In [26]:
rdd = df.select("pickup_datetime", "PULocationID", "total_amount").rdd

In [27]:
rdd.take(10)

[Row(pickup_datetime=datetime.datetime(2020, 1, 1, 0, 28, 15), PULocationID=238, total_amount=11.27),
 Row(pickup_datetime=datetime.datetime(2020, 1, 1, 0, 35, 39), PULocationID=239, total_amount=12.3),
 Row(pickup_datetime=datetime.datetime(2020, 1, 1, 0, 47, 41), PULocationID=238, total_amount=10.8),
 Row(pickup_datetime=datetime.datetime(2020, 1, 1, 0, 55, 23), PULocationID=238, total_amount=8.16),
 Row(pickup_datetime=datetime.datetime(2020, 1, 1, 0, 1, 58), PULocationID=193, total_amount=4.8),
 Row(pickup_datetime=datetime.datetime(2020, 1, 1, 0, 9, 44), PULocationID=7, total_amount=3.8),
 Row(pickup_datetime=datetime.datetime(2020, 1, 1, 0, 39, 25), PULocationID=193, total_amount=3.81),
 Row(pickup_datetime=datetime.datetime(2019, 12, 18, 15, 27, 49), PULocationID=193, total_amount=2.81),
 Row(pickup_datetime=datetime.datetime(2019, 12, 18, 15, 30, 35), PULocationID=193, total_amount=6.3),
 Row(pickup_datetime=datetime.datetime(2020, 1, 1, 0, 29, 1), PULocationID=246, total_amoun

In [28]:
start = datetime(year=2020, month=1, day=1)


def filter_outliers(row):
    return row.pickup_datetime >= start

In [29]:
rows = rdd.take(10)
row = rows[0]

row

Row(pickup_datetime=datetime.datetime(2020, 1, 1, 0, 28, 15), PULocationID=238, total_amount=11.27)

In [30]:
def prepare_for_grouping(row):
    hour = row.pickup_datetime.replace(minute=0, second=0, microsecond=0)
    zone = row.PULocationID
    key = (hour, zone)

    amount = row.total_amount
    count = 1
    value = (amount, count)

    return (key, value)

In [31]:
prepare_for_grouping(row)

((datetime.datetime(2020, 1, 1, 0, 0), 238), (11.27, 1))

In [32]:
def calculate_revenue(left_value, right_value):
    left_amount, left_count = left_value
    right_amount, right_count = right_value

    output_amount = left_amount + right_amount
    output_count = left_count + right_count

    return (output_amount, output_count)

In [33]:
from collections import namedtuple

RevenueRow = namedtuple("RevenueRow", ["hour", "zone", "revenue", "count"])

In [34]:
def unwrap(row):
    return RevenueRow(
        hour=row[0][0], zone=row[0][1], revenue=row[1][0], count=row[1][1]
    )

In [35]:
from pyspark.sql import types

In [36]:
result_schema = types.StructType(
    [
        types.StructField("hour", types.TimestampType(), True),
        types.StructField("zone", types.IntegerType(), True),
        types.StructField("revenue", types.DoubleType(), True),
        types.StructField("count", types.IntegerType(), True),
    ]
)

In [37]:
df_result = (
    rdd.filter(filter_outliers)
    .map(prepare_for_grouping)
    .reduceByKey(calculate_revenue)
    .map(unwrap)
    .toDF(result_schema)
)

In [38]:
df_result.schema

StructType([StructField('hour', TimestampType(), True), StructField('zone', IntegerType(), True), StructField('revenue', DoubleType(), True), StructField('count', IntegerType(), True)])

In [39]:
df_result.show()



+-------------------+----+------------------+-----+
|               hour|zone|           revenue|count|
+-------------------+----+------------------+-----+
|2020-01-01 00:00:00| 236| 5254.590000000019|  339|
|2020-01-01 00:00:00| 142|  9252.30000000001|  488|
|2020-01-01 00:00:00|  90|  5010.45000000001|  266|
|2020-01-01 00:00:00| 260|             74.14|    8|
|2020-01-01 01:00:00|  90| 4326.070000000006|  231|
|2020-01-01 00:00:00|  60|57.620000000000005|    2|
|2020-01-01 01:00:00| 234| 6217.040000000017|  347|
|2020-01-01 01:00:00| 134|              9.75|    1|
|2020-01-01 02:00:00|  75|1314.5599999999984|   85|
|2020-01-01 02:00:00| 249| 6182.610000000019|  299|
|2020-01-01 01:00:00| 228|             74.07|    2|
|2020-01-01 02:00:00|  41|1480.5899999999983|   91|
|2020-01-01 02:00:00| 145| 596.0499999999998|   32|
|2020-01-01 02:00:00| 125|1753.1899999999994|   74|
|2020-01-01 03:00:00| 137|2045.3099999999977|  116|
|2020-01-01 02:00:00| 179| 300.6100000000002|   23|
|2020-01-01 

                                                                                

In [40]:
df_result.write.parquet("tmp/value", mode="overwrite")

24/03/06 06:06:34 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
24/03/06 06:06:34 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 84.44% for 9 writers
24/03/06 06:06:34 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 76.00% for 10 writers
24/03/06 06:06:34 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 69.09% for 11 writers
24/03/06 06:06:34 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 63.33% for 12 writers
24/03/06 06:06:34 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 58.46% for 13 writers
24/03/06 06:06:34 WARN MemoryManager: Total allocation exceeds 95.

In [43]:
columns = [
    "VendorID",
    "pickup_datetime",
    "PULocationID",
    "DOLocationID",
    "trip_distance",
]
service_type = "green"
df = spark.read.parquet(f"data/raw/{service_type}/*/*")
df.columns
if "lpep_pickup_datetime" in df.columns:
    df = df.withColumnRenamed(
        "lpep_pickup_datetime", "pickup_datetime"
    ).withColumnRenamed("lpep_dropoff_datetime", "dropoff_datetime")
elif "tpep_pickup_datetime" in df.columns:
    df = df.withColumnRenamed(
        "tpep_pickup_datetime", "pickup_datetime"
    ).withColumnRenamed("tpep_dropoff_datetime", "dropoff_datetime")

duration_rdd = df.select(columns).rdd

In [44]:
duration_rdd.take(10)

[Row(VendorID=2, pickup_datetime=datetime.datetime(2019, 12, 18, 15, 52, 30), PULocationID=264, DOLocationID=264, trip_distance=0.0),
 Row(VendorID=2, pickup_datetime=datetime.datetime(2020, 1, 1, 0, 45, 58), PULocationID=66, DOLocationID=65, trip_distance=1.28),
 Row(VendorID=2, pickup_datetime=datetime.datetime(2020, 1, 1, 0, 41, 38), PULocationID=181, DOLocationID=228, trip_distance=2.47),
 Row(VendorID=1, pickup_datetime=datetime.datetime(2020, 1, 1, 0, 52, 46), PULocationID=129, DOLocationID=263, trip_distance=6.3),
 Row(VendorID=1, pickup_datetime=datetime.datetime(2020, 1, 1, 0, 19, 57), PULocationID=210, DOLocationID=150, trip_distance=2.3),
 Row(VendorID=1, pickup_datetime=datetime.datetime(2020, 1, 1, 0, 52, 33), PULocationID=35, DOLocationID=39, trip_distance=3.0),
 Row(VendorID=2, pickup_datetime=datetime.datetime(2020, 1, 1, 0, 10, 18), PULocationID=25, DOLocationID=61, trip_distance=2.77),
 Row(VendorID=2, pickup_datetime=datetime.datetime(2020, 1, 1, 1, 3, 14), PULocatio

In [45]:
def infinite_seq():
    i = 0
    while True:
        yield i
        i += 1

In [None]:
seq = infinite_seq()

In [None]:
for i in seq:
    print(i)

    if i > 10:
        break

0
1
2
3
4
5
6
7
8
9
10
11


In [46]:
import pandas as pd


def model_predict(df):
    # y_pred = model.predict(df)
    y_pred = df.trip_distance * 5
    return y_pred


def apply_model_in_batch(partition):
    df = pd.DataFrame(rows, columns=columns)
    predictions = model_predict(df)
    df["predicted_duration"] = predictions

    for row in df.itertuples():
        yield row

In [47]:
duration_rdd.mapPartitions(apply_model_in_batch).toDF().take(10)

24/03/06 06:10:03 ERROR Executor: Exception in task 0.0 in stage 50.0 (TID 307)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/Users/jonathondonager/Documents/GitHub/data-engineering-learn/.venv/lib/python3.11/site-packages/pandas/core/internals/construction.py", line 939, in _finalize_columns_and_data
    columns = _validate_or_indexify_columns(contents, columns)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathondonager/Documents/GitHub/data-engineering-learn/.venv/lib/python3.11/site-packages/pandas/core/internals/construction.py", line 986, in _validate_or_indexify_columns
    raise AssertionError(
AssertionError: 5 columns passed, passed data had 3 columns

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/jonathondonager/Documents/GitHub/data-engineering-learn/.venv/lib/python3.11/site-packages/pyspark/python/lib/pyspark.zip/pyspark/wo

Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.runJob.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 50.0 failed 1 times, most recent failure: Lost task 0.0 in stage 50.0 (TID 307) (192.168.1.38 executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/Users/jonathondonager/Documents/GitHub/data-engineering-learn/.venv/lib/python3.11/site-packages/pandas/core/internals/construction.py", line 939, in _finalize_columns_and_data
    columns = _validate_or_indexify_columns(contents, columns)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathondonager/Documents/GitHub/data-engineering-learn/.venv/lib/python3.11/site-packages/pandas/core/internals/construction.py", line 986, in _validate_or_indexify_columns
    raise AssertionError(
AssertionError: 5 columns passed, passed data had 3 columns

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/jonathondonager/Documents/GitHub/data-engineering-learn/.venv/lib/python3.11/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 1247, in main
    process()
  File "/Users/jonathondonager/Documents/GitHub/data-engineering-learn/.venv/lib/python3.11/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 1239, in process
    serializer.dump_stream(out_iter, outfile)
  File "/Users/jonathondonager/Documents/GitHub/data-engineering-learn/.venv/lib/python3.11/site-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 274, in dump_stream
    vs = list(itertools.islice(iterator, batch))
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathondonager/Documents/GitHub/data-engineering-learn/.venv/lib/python3.11/site-packages/pyspark/rdd.py", line 2849, in takeUpToNumLeft
    yield next(iterator)
          ^^^^^^^^^^^^^^
  File "/var/folders/wp/fzyspr3j2lq6vw4th77_dm4h0000gq/T/ipykernel_41277/2195418199.py", line 11, in apply_model_in_batch
  File "/Users/jonathondonager/Documents/GitHub/data-engineering-learn/.venv/lib/python3.11/site-packages/pandas/core/frame.py", line 840, in __init__
    arrays, columns, index = nested_data_to_arrays(
                             ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathondonager/Documents/GitHub/data-engineering-learn/.venv/lib/python3.11/site-packages/pandas/core/internals/construction.py", line 520, in nested_data_to_arrays
    arrays, columns = to_arrays(data, columns, dtype=dtype)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathondonager/Documents/GitHub/data-engineering-learn/.venv/lib/python3.11/site-packages/pandas/core/internals/construction.py", line 845, in to_arrays
    content, columns = _finalize_columns_and_data(arr, columns, dtype)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathondonager/Documents/GitHub/data-engineering-learn/.venv/lib/python3.11/site-packages/pandas/core/internals/construction.py", line 942, in _finalize_columns_and_data
    raise ValueError(err) from err
ValueError: 5 columns passed, passed data had 3 columns

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:572)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:784)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:766)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:525)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:366)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:364)
	at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:358)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:358)
	at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:345)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:339)
	at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
	at org.apache.spark.api.python.PythonRDD$.$anonfun$runJob$1(PythonRDD.scala:181)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2438)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	at java.base/java.lang.Thread.run(Thread.java:829)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2856)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2792)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2791)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2791)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1247)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3060)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2994)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2983)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:989)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2398)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2419)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2438)
	at org.apache.spark.api.python.PythonRDD$.runJob(PythonRDD.scala:181)
	at org.apache.spark.api.python.PythonRDD.runJob(PythonRDD.scala)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:829)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/Users/jonathondonager/Documents/GitHub/data-engineering-learn/.venv/lib/python3.11/site-packages/pandas/core/internals/construction.py", line 939, in _finalize_columns_and_data
    columns = _validate_or_indexify_columns(contents, columns)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathondonager/Documents/GitHub/data-engineering-learn/.venv/lib/python3.11/site-packages/pandas/core/internals/construction.py", line 986, in _validate_or_indexify_columns
    raise AssertionError(
AssertionError: 5 columns passed, passed data had 3 columns

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/jonathondonager/Documents/GitHub/data-engineering-learn/.venv/lib/python3.11/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 1247, in main
    process()
  File "/Users/jonathondonager/Documents/GitHub/data-engineering-learn/.venv/lib/python3.11/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 1239, in process
    serializer.dump_stream(out_iter, outfile)
  File "/Users/jonathondonager/Documents/GitHub/data-engineering-learn/.venv/lib/python3.11/site-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 274, in dump_stream
    vs = list(itertools.islice(iterator, batch))
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathondonager/Documents/GitHub/data-engineering-learn/.venv/lib/python3.11/site-packages/pyspark/rdd.py", line 2849, in takeUpToNumLeft
    yield next(iterator)
          ^^^^^^^^^^^^^^
  File "/var/folders/wp/fzyspr3j2lq6vw4th77_dm4h0000gq/T/ipykernel_41277/2195418199.py", line 11, in apply_model_in_batch
  File "/Users/jonathondonager/Documents/GitHub/data-engineering-learn/.venv/lib/python3.11/site-packages/pandas/core/frame.py", line 840, in __init__
    arrays, columns, index = nested_data_to_arrays(
                             ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathondonager/Documents/GitHub/data-engineering-learn/.venv/lib/python3.11/site-packages/pandas/core/internals/construction.py", line 520, in nested_data_to_arrays
    arrays, columns = to_arrays(data, columns, dtype=dtype)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathondonager/Documents/GitHub/data-engineering-learn/.venv/lib/python3.11/site-packages/pandas/core/internals/construction.py", line 845, in to_arrays
    content, columns = _finalize_columns_and_data(arr, columns, dtype)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jonathondonager/Documents/GitHub/data-engineering-learn/.venv/lib/python3.11/site-packages/pandas/core/internals/construction.py", line 942, in _finalize_columns_and_data
    raise ValueError(err) from err
ValueError: 5 columns passed, passed data had 3 columns

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:572)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:784)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:766)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:525)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at org.apache.spark.InterruptibleIterator.foreach(InterruptibleIterator.scala:28)
	at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
	at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:105)
	at scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:49)
	at scala.collection.TraversableOnce.to(TraversableOnce.scala:366)
	at scala.collection.TraversableOnce.to$(TraversableOnce.scala:364)
	at org.apache.spark.InterruptibleIterator.to(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toBuffer(TraversableOnce.scala:358)
	at scala.collection.TraversableOnce.toBuffer$(TraversableOnce.scala:358)
	at org.apache.spark.InterruptibleIterator.toBuffer(InterruptibleIterator.scala:28)
	at scala.collection.TraversableOnce.toArray(TraversableOnce.scala:345)
	at scala.collection.TraversableOnce.toArray$(TraversableOnce.scala:339)
	at org.apache.spark.InterruptibleIterator.toArray(InterruptibleIterator.scala:28)
	at org.apache.spark.api.python.PythonRDD$.$anonfun$runJob$1(PythonRDD.scala:181)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2438)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	... 1 more
