In [2]:
%%info

ID,YARN Application ID,Kind,State,Spark UI,Driver log,User,Current session?
3735,application_1713270977862_4021,pyspark,idle,Link,Link,,
3824,application_1713270977862_4118,pyspark,idle,Link,Link,,
3829,application_1713270977862_4123,pyspark,idle,Link,Link,,
3830,application_1713270977862_4124,pyspark,idle,Link,Link,,
3831,application_1713270977862_4125,pyspark,idle,Link,Link,,✔


In [3]:
print(f'Start Spark name:{spark._sc.appName}, version:{spark.version}')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Start Spark name:livy-session-3831, version:3.3.2.3.3.7190.2-1

In [4]:
%%local
import os
default_db = 'com490'
username=os.getenv('USER', 'anonymous')
hadoop_fs=os.getenv('HADOOP_DEFAULT_FS', 'hdfs://iccluster067.iccluster.epfl.ch:8020')
print(f"local username={username}\nhadoop_fs={hadoop_fs}")

local username=kvaerum
hadoop_fs=hdfs://iccluster067.iccluster.epfl.ch:8020


### Data Retrieval

In [65]:
from pyspark import SparkContext
sc = SparkContext.getOrCreate()

# List files in the directory
files = sc._jvm.org.apache.hadoop.fs.FileSystem.get(sc._jsc.hadoopConfiguration()) \
    .listStatus(sc._jvm.org.apache.hadoop.fs.Path('/data/sbb/csv/timetables/stops'))
for file_status in files:
    print(file_status.getPath())

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

hdfs://iccluster067.iccluster.epfl.ch:8020/data/sbb/csv/timetables/stops/year=2020
hdfs://iccluster067.iccluster.epfl.ch:8020/data/sbb/csv/timetables/stops/year=2021
hdfs://iccluster067.iccluster.epfl.ch:8020/data/sbb/csv/timetables/stops/year=2022
hdfs://iccluster067.iccluster.epfl.ch:8020/data/sbb/csv/timetables/stops/year=2023
hdfs://iccluster067.iccluster.epfl.ch:8020/data/sbb/csv/timetables/stops/year=2024

## Istdaten Processing

### 1) Select Relevant Columns and Rename for Convenience

In [137]:
istdaten_df = spark.read.orc('/data/sbb/orc/istdaten/year=2023')

istdaten_df = istdaten_df.select([
  istdaten_df['betriebstag'].alias('date'),
  istdaten_df['fahrt_bezeichner'].alias('trip_id'),
  istdaten_df['betreiber_id'].alias('operator_id'),
  istdaten_df['bpuic'].alias('stop_id'),
  istdaten_df['produkt_id'].alias('transportation'),
  istdaten_df['faellt_aus_tf'].alias('cancelled'),
  istdaten_df['ankunftszeit'].alias('arrival_time'),
  istdaten_df['an_prognose'].alias('arrival_prognosis'),
  istdaten_df['an_prognose_status'].alias('arr_prog_status'),
  istdaten_df['abfahrtszeit'].alias('departure_time'),
  istdaten_df['ab_prognose'].alias('departure_prognosis'),
  istdaten_df['ab_prognose_status'].alias('dep_prog_status'),
])

istdaten_df.show(5)
istdaten_df.printSchema()
istdaten_df.count()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+----------+--------------+-----------+-------+--------------+---------+----------------+-------------------+---------------+----------------+-------------------+---------------+
|      date|       trip_id|operator_id|stop_id|transportation|cancelled|    arrival_time|  arrival_prognosis|arr_prog_status|  departure_time|departure_prognosis|dep_prog_status|
+----------+--------------+-----------+-------+--------------+---------+----------------+-------------------+---------------+----------------+-------------------+---------------+
|22.12.2023|85:65:8031:001|      85:65|8506101|           Zug|    false|22.12.2023 09:36|22.12.2023 09:36:34|           REAL|22.12.2023 09:36|22.12.2023 09:36:53|           REAL|
|22.12.2023|85:65:8031:001|      85:65|8506102|           Zug|    false|22.12.2023 09:39|22.12.2023 09:39:08|           REAL|22.12.2023 09:39|22.12.2023 09:39:26|           REAL|
|22.12.2023|85:65:8031:001|      85:65|8506103|           Zug|    false|22.12.2023 09:42|22.12.2023 09:42

### 2) Filtering and Parsing

In [141]:
distinct_arr_prog_status = istdaten_df.select("arr_prog_status").distinct().collect()
distinct_arr_prog_status_list = [row.arr_prog_status for row in distinct_arr_prog_status]
print(distinct_arr_prog_status_list)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

['GESCHAETZT', 'PROGNOSE', 'REAL']

In [142]:
from pyspark.sql.functions import to_timestamp, col
format_date = "dd.MM.yyyy"
format_timetable = "dd.MM.yyyy HH:mm"
format_prognosis = "dd.MM.yyyy HH:mm:ss"

istdaten_df = istdaten_df.withColumn("date", to_timestamp(col("date"), format_date))\
                         .withColumn("arrival_time", to_timestamp(col("arrival_time"), format_timetable))\
                         .withColumn("arrival_prognosis", to_timestamp(col("arrival_prognosis"), format_prognosis))\
                         .withColumn("departure_time", to_timestamp(col("departure_time"), format_timetable))\
                         .withColumn("departure_prognosis", to_timestamp(col("departure_prognosis"), format_prognosis))

istdaten_df = istdaten_df.filter((col("arr_prog_status") != "") & (col("arr_prog_status") != "UNBEKANNT") & (col('cancelled') == 'false'))

istdaten_df.show(5)
istdaten_df.printSchema()
istdaten_df.count()


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------------------+--------------+-----------+-------+--------------+---------+-------------------+-------------------+---------------+-------------------+-------------------+---------------+-------------+---------------+----------------+----+-----+---+-----------+----------+----------+------------+
|               date|       trip_id|operator_id|stop_id|transportation|cancelled|       arrival_time|  arrival_prognosis|arr_prog_status|     departure_time|departure_prognosis|dep_prog_status|delay_arrival|delay_departure|delay_at_station|year|month|day|day_of_week|is_weekday|is_weekend|is_peak_time|
+-------------------+--------------+-----------+-------+--------------+---------+-------------------+-------------------+---------------+-------------------+-------------------+---------------+-------------+---------------+----------------+----+-----+---+-----------+----------+----------+------------+
|2023-12-22 00:00:00|85:65:8031:001|      85:65|8506101|           Zug|    false|2023-12-22

### 3) Istdaten Feature Engineering

In [144]:
from pyspark.sql.functions import col, to_date, year, month, dayofmonth, dayofweek, expr, when
from pyspark.sql.types import IntegerType

istdaten_df = istdaten_df.withColumn("delay_arrival", (col("arrival_prognosis").cast("long") - col("arrival_time").cast("long")))\
                         .withColumn("delay_departure", (col("departure_prognosis").cast("long") - col("departure_time").cast("long")))\
                         .withColumn("delay_at_station", (col("delay_departure") - col("delay_arrival")))\
                         .withColumn("year", year(col("date")))\
                         .withColumn("month", month(col("date")))\
                         .withColumn("day", dayofmonth(col("date")))\
                         .withColumn("day_of_week", (dayofweek(col("date")) + 5) % 7)\
                         .withColumn("is_weekday", (col("day_of_week") <= 4))\
                         .withColumn("is_weekend", (col("day_of_week") > 4))\
                         .withColumn("is_peak_time",
                                     when((col("is_weekday")) & 
                                          ((col("arrival_time").between(expr("make_timestamp(year(date), month(date), day(date), 6, 30, 0)"), 
                                                                        expr("make_timestamp(year(date), month(date), day(date), 8, 30, 0)"))) |
                                           (col("arrival_time").between(expr("make_timestamp(year(date), month(date), day(date), 16, 30, 0)"), 
                                                                        expr("make_timestamp(year(date), month(date), day(date), 18, 30, 0)")))), True)
                                     .otherwise(False))
istdaten_df.show(5)
istdaten_df.printSchema()
istdaten_df.count()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------------------+--------------+-----------+-------+--------------+---------+-------------------+-------------------+---------------+-------------------+-------------------+---------------+-------------+---------------+----------------+----+-----+---+-----------+----------+----------+------------+
|               date|       trip_id|operator_id|stop_id|transportation|cancelled|       arrival_time|  arrival_prognosis|arr_prog_status|     departure_time|departure_prognosis|dep_prog_status|delay_arrival|delay_departure|delay_at_station|year|month|day|day_of_week|is_weekday|is_weekend|is_peak_time|
+-------------------+--------------+-----------+-------+--------------+---------+-------------------+-------------------+---------------+-------------------+-------------------+---------------+-------------+---------------+----------------+----+-----+---+-----------+----------+----------+------------+
|2023-12-22 00:00:00|85:65:8031:001|      85:65|8506101|           Zug|    false|2023-12-22

### 4) Additional Features

In [146]:
from pyspark.sql.types import DoubleType
stops_df = spark.read.csv('/data/sbb/csv/timetables/stops', header=True)\
                .select(['stop_id','stop_lat', 'stop_lon'])\
                .drop(*['year','month','day'])\
                .dropDuplicates(['stop_id'])\
                .withColumnRenamed("stop_id", "stops_stop_id")\
                .withColumn("stop_lat", col("stop_lat").cast(DoubleType()))\
                .withColumn("stop_lon", col("stop_lon").cast(DoubleType()))

istdaten_df = istdaten_df.join(
    stops_df,
    istdaten_df.stop_id == stops_df.stops_stop_id,
    "left"
).drop('stops_stop_id')

istdaten_df.show(5)
istdaten_df.printSchema()
istdaten_df.count()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------------------+--------------------+-----------+-------+--------------+---------+-------------------+-------------------+---------------+-------------------+-------------------+---------------+-------------+---------------+----------------+----+-----+---+-----------+----------+----------+------------+----------------+----------------+
|               date|             trip_id|operator_id|stop_id|transportation|cancelled|       arrival_time|  arrival_prognosis|arr_prog_status|     departure_time|departure_prognosis|dep_prog_status|delay_arrival|delay_departure|delay_at_station|year|month|day|day_of_week|is_weekday|is_weekend|is_peak_time|        stop_lat|        stop_lon|
+-------------------+--------------------+-----------+-------+--------------+---------+-------------------+-------------------+---------------+-------------------+-------------------+---------------+-------------+---------------+----------------+----+-----+---+-----------+----------+----------+------------+------

In [159]:
from pyspark.sql.window import Window
from pyspark.sql.functions import udf, count, lag, col, to_timestamp, concat_ws, lit, lpad
from pyspark.sql.types import DoubleType
import math

# Haversine formula to calculate the distance between two points on the earth
def haversine(lat1, lon1, lat2, lon2):
    if lat1 is None or lon1 is None or lat2 is None or lon2 is None:
        return None
    R = 6371  # Radius of the Earth in kilometers
    dlat = math.radians(lat2 - lat1)
    dlon = math.radians(lon2 - lon1)
    a = math.sin(dlat / 2) * math.sin(dlat / 2) + math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) * math.sin(dlon / 2) * math.sin(dlon / 2)
    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
    distance = R * c  # Distance in kilometers
    return distance

# Register the UDF
haversine_udf = udf(haversine, DoubleType())


stop_times_df = spark.read.csv('/data/sbb/csv/timetables/stop_times/year=2023', header=True)\
                             .withColumn("year", lit(2023))\
                             .withColumn("month", lpad(col("month").cast("string"), 2, "0"))\
                             .withColumn("day", lpad(col("day").cast("string"), 2, "0"))\
                             .withColumn("arrival_time", concat_ws(" ", concat_ws("-", col("year"), col("month"), col("day")), col("arrival_time")))\
                             .withColumn("departure_time", concat_ws(" ", concat_ws("-", col("year"), col("month"), col("day")), col("departure_time")))\
                             .withColumn("arrival_time", to_timestamp(col("arrival_time"), "yyyy-MM-dd HH:mm:ss"))\
                             .withColumn("departure_time", to_timestamp(col("departure_time"), "yyyy-MM-dd HH:mm:ss"))

stop_times_df = stop_times_df.join(
    stops_df,
    stop_times_df.stop_id == stops_df.stops_stop_id,
    "left"
).drop('stops_stop_id')

# Add a column representing the number of stops for each trip_id
window_spec_trip = Window.partitionBy("trip_id", "year", "month", "day")
stop_times_df = stop_times_df.withColumn("number_of_stops", count("stop_id").over(window_spec_trip))

# Add columns for the previous stop_id and departure_time
window_spec_seq = Window.partitionBy("trip_id", "year", "month", "day").orderBy("stop_sequence")
stop_times_df = stop_times_df.withColumn("prev_stop_id", lag("stop_id", 1).over(window_spec_seq))\
                             .withColumn("prev_departure_time", lag("departure_time", 1).over(window_spec_seq))\
                             .withColumn("prev_stop_lat", lag("stop_lat", 1).over(window_spec_seq))\
                             .withColumn("prev_stop_lon", lag("stop_lon", 1).over(window_spec_seq))\
                             .withColumn("traveling_time_min", ((col("arrival_time").cast("long") - col("prev_departure_time").cast("long"))) / 60)\
                             .withColumn("traveling_distance_km", haversine_udf(col("prev_stop_lat"), col("prev_stop_lon"),col("stop_lat"), col("stop_lon")))
 
stop_times_df.show(5)
stop_times_df.printSchema()
stop_times_df.count()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+-------------------+-------------------+-------+-------------+-----------+-------------+-----+---+----+----------------+----------------+---------------+------------+-------------------+----------------+----------------+------------------+---------------------+
|             trip_id|       arrival_time|     departure_time|stop_id|stop_sequence|pickup_type|drop_off_type|month|day|year|        stop_lat|        stop_lon|number_of_stops|prev_stop_id|prev_departure_time|   prev_stop_lat|   prev_stop_lon|traveling_time_min|traveling_distance_km|
+--------------------+-------------------+-------------------+-------+-------------+-----------+-------------+-----+---+----+----------------+----------------+---------------+------------+-------------------+----------------+----------------+------------------+---------------------+
|1.TA.91-10-j23-1.1.H|2023-08-23 19:42:00|2023-08-23 19:42:00|8500065|            1|          0|            0|   08| 23|2023|47.4837805948106|7.5460

In [172]:
right = stop_times_df.select(['trip_id', 'stop_id', 'arrival_time', 'stop_sequence', 'number_of_stops', 'traveling_time_min', 'traveling_distance_km'])\
                            .dropDuplicates(['trip_id', 'stop_id', 'arrival_time'])
left = istdaten_df.select(['date','stop_id', 'transportation', 'trip_id', 'delay_arrival', 'delay_departure', 'delay_at_station', 'day_of_week', 'is_weekend', 'is_peak_time', 'stop_lat', 'stop_lon'])

merge = left.join(
    right,
    (left.trip_id == right.trip_id) & (left.stop_id == right.stop_id) & (left.date.cast('date') == right.arrival_time.cast('date')),
    "left"
)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [174]:
merge.show(10)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------------------+-------+--------------+--------------------+-------------+---------------+----------------+-----------+----------+------------+----------------+----------------+-------+-------+------------+-------------+---------------+------------------+---------------------+
|               date|stop_id|transportation|             trip_id|delay_arrival|delay_departure|delay_at_station|day_of_week|is_weekend|is_peak_time|        stop_lat|        stop_lon|trip_id|stop_id|arrival_time|stop_sequence|number_of_stops|traveling_time_min|traveling_distance_km|
+-------------------+-------+--------------+--------------------+-------------+---------------+----------------+-----------+----------+------------+----------------+----------------+-------+-------+------------+-------------+---------------+------------------+---------------------+
|2023-12-30 00:00:00|8506290|           Zug|      85:22:1073:000|          -77|             74|             151|          5|      true|       false|47.

### TODO: Figure Out How to Merge Istdaten with other tables. Currently trip_id mismatch...

In [163]:
stop_times_df.select(['trip_id', 'stop_id', 'arrival_time', 'stop_sequence', 'number_of_stops', 'traveling_time_min', 'traveling_distance_km']).show(10)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+------------+-------------------+-------------+---------------+------------------+---------------------+
|             trip_id|     stop_id|       arrival_time|stop_sequence|number_of_stops|traveling_time_min|traveling_distance_km|
+--------------------+------------+-------------------+-------------+---------------+------------------+---------------------+
|1.TA.91-10-j23-1.1.H|     8500065|2023-08-23 19:42:00|            1|              4|              null|                 null|
|1.TA.91-10-j23-1.1.H|     8588795|2023-08-23 19:44:00|            2|              4|               2.0|   1.4158613162755302|
|1.TA.91-10-j23-1.1.H|     8500064|2023-08-23 19:45:00|            3|              4|               1.0|   0.5080517179755902|
|1.TA.91-10-j23-1.1.H|     8500072|2023-08-23 19:47:00|            4|              4|               2.0|   0.8143225302094911|
|1.TA.91-10-j23-1.1.H|     8500065|2023-10-18 19:57:00|            1|              4|              null|       

In [168]:
istdaten_df.select(['date','stop_id', 'transportation', 'trip_id', 'delay_arrival', 'delay_departure', 'delay_at_station', 'day_of_week', 'is_weekend', 'is_peak_time', 'stop_lat', 'stop_lon']).show(5)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------------------+-------+--------------+--------------------+-------------+---------------+----------------+-----------+----------+------------+----------------+----------------+
|               date|stop_id|transportation|             trip_id|delay_arrival|delay_departure|delay_at_station|day_of_week|is_weekend|is_peak_time|        stop_lat|        stop_lon|
+-------------------+-------+--------------+--------------------+-------------+---------------+----------------+-----------+----------+------------+----------------+----------------+
|2023-10-11 00:00:00|8589289|           Bus|85:823:36715-00009-1|           61|              9|             -52|          2|     false|       false|47.5528391160685|7.55672698574657|
|2023-11-06 00:00:00|8588156|           Bus|85:151:TL031-4506...|           89|            105|              16|          0|     false|       false|46.5112787626392|6.55061386739266|
|2023-11-06 00:00:00|8588156|           Bus|85:151:TL031-4506...|            0|      

In [175]:
spark.read.orc('/data/sbb/orc/istdaten/year=2023').show(5)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-----------+----------------+------------+-------------+--------------+----------+---------+-----------+---------+-------------------+--------------+-------------+-------+--------------------+----------------+-------------------+------------------+----------------+-------------------+------------------+-------------+-----+
|betriebstag|fahrt_bezeichner|betreiber_id|betreiber_abk|betreiber_name|produkt_id|linien_id|linien_text|umlauf_id|verkehrsmittel_text|zusatzfahrt_tf|faellt_aus_tf|  bpuic|   haltestellen_name|    ankunftszeit|        an_prognose|an_prognose_status|    abfahrtszeit|        ab_prognose|ab_prognose_status|durchfahrt_tf|month|
+-----------+----------------+------------+-------------+--------------+----------+---------+-----------+---------+-------------------+--------------+-------------+-------+--------------------+----------------+-------------------+------------------+----------------+-------------------+------------------+-------------+-----+
| 22.12.2023|  85:65:8

# Delay Prediction Model

In [176]:
features = istdaten_df.select(['date','stop_id', 'transportation', 'trip_id', 'delay_arrival', 'delay_departure', 'delay_at_station', 'day_of_week', 'is_weekend', 'is_peak_time', 'stop_lat', 'stop_lon'])

features.show(10)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-------------------+-------+--------------+--------------------+-------------+---------------+----------------+-----------+----------+------------+----------------+----------------+
|               date|stop_id|transportation|             trip_id|delay_arrival|delay_departure|delay_at_station|day_of_week|is_weekend|is_peak_time|        stop_lat|        stop_lon|
+-------------------+-------+--------------+--------------------+-------------+---------------+----------------+-----------+----------+------------+----------------+----------------+
|2023-03-23 00:00:00|8501214|         Metro|85:151:TL070-4506...|           81|            155|              74|          3|     false|       false|46.5221956281987| 6.5661367555044|
|2023-12-22 00:00:00|8506100|           Zug|      85:65:8032:001|           11|             18|               7|          4|     false|        true| 47.558162001326|8.89656423219735|
|2023-12-22 00:00:00|8506101|           Zug|      85:65:8031:001|           34|      

In [177]:
distinct_transportation = istdaten_df.select("transportation").distinct().collect()
distinct_transportation = [row.transportation for row in distinct_transportation]
print(distinct_transportation)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

['BUS', 'Tram', 'Zug', 'Bus', '', 'Zahnradbahn', 'Metro', 'CS', 'WM-BUS', 'Taxi']

In [186]:
from pyspark.sql.functions import col, lower, when
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.ml import Pipeline

features = istdaten_df.select(['stop_id', 'transportation', 'delay_arrival', 'day_of_week', 'is_weekend', 'is_peak_time', 'stop_lat', 'stop_lon'])\
                        .withColumn("transportation", lower(col("transportation")))\
                        .withColumn("transportation",\
                           when(col("transportation")=="" ,None) \
                              .otherwise(col("transportation")))\
                        .fillna({
                            "delay_arrival": 0
                        })

transportation_indexer = StringIndexer(inputCol="transportation", outputCol="transportation_index")
transportation_encoder = OneHotEncoder(inputCol="transportation_index", outputCol="transportation_vec")

assembler = VectorAssembler(
    inputCols=[
        "transportation_vec",
        "day_of_week",
        "is_weekend",
        "is_peak_time",
        "stop_lat",
        "stop_lon"
    ],
    outputCol="features"
)

pipeline = Pipeline(stages=[
    transportation_indexer,
    transportation_encoder,
    assembler
])

# Fit the pipeline to the data
pipeline_model = pipeline.fit(features)
prepared_features = pipeline_model.transform(features)

# Select the features and the target variable for the model
final_data = prepared_features.select(col("features"), col("delay_arrival").alias("label"))

final_data.show(5)


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+-----+
|            features|label|
+--------------------+-----+
|(12,[0,7,10,11],[...|    2|
|(12,[0,7,10,11],[...|   46|
|(12,[0,7,9,10,11]...|  228|
|(12,[0,7,10,11],[...|  -20|
|(12,[0,7,9,10,11]...|   60|
+--------------------+-----+
only showing top 5 rows

In [187]:
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.sql.functions import col
from pyspark.ml import Pipeline

# Split the data into training and testing sets
train_data, test_data = final_data.randomSplit([0.8, 0.2], seed=42)

# Initialize the RandomForestRegressor
rf = RandomForestRegressor(featuresCol="features", labelCol="label", numTrees=100, maxDepth=10)

# Create the pipeline
pipeline = Pipeline(stages=[rf])

# Train the model
model = pipeline.fit(train_data)

# Make predictions
predictions = model.transform(test_data)

# Evaluate the model
evaluator = RegressionEvaluator(labelCol="label", predictionCol="prediction", metricName="rmse")
rmse = evaluator.evaluate(predictions)
print(f"Root Mean Squared Error (RMSE) on test data: {rmse}")

# Show some predictions
predictions.select("prediction", "label", "features").show(5)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

An error was encountered:
An error occurred while calling o3957.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 767.0 failed 4 times, most recent failure: Lost task 0.3 in stage 767.0 (TID 45075) (iccluster072.iccluster.epfl.ch executor 2622): org.apache.spark.SparkException: Failed to execute user defined function (StringIndexerModel$$Lambda$5584/0x00007f1dfd5559b0: (string) => double)
	at org.apache.spark.sql.errors.QueryExecutionErrors$.failedExecuteUserDefinedFunctionError(QueryExecutionErrors.scala:190)
	at org.apache.spark.sql.errors.QueryExecutionErrors.failedExecuteUserDefinedFunctionError(QueryExecutionErrors.scala)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage5.sort_addToSorter_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage5.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(Buffered