In [1]:
from pyspark.ml import PipelineModel

# Cloud Storage location the saved pipeline is read from.
model_path = "gs://model_stored_for_pred_forecast/hourly_prediction/"
# Load the persisted PipelineModel from that path; later cells use `model` for scoring.
model = PipelineModel.load(model_path)

                                                                                

In [2]:
# Sanity check: display the class of the object that was just loaded.
type(model)

pyspark.ml.pipeline.PipelineModel

In [3]:
# Display the stages contained in the loaded pipeline, in order.
model.stages

[VectorAssembler_5e181915096d, SparkXGBRegressor_ac03c6c3bb22]

In [7]:
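# Optional check: uncomment the next line to list the input feature columns
# expected by the first pipeline stage (the VectorAssembler).
# model.stages[0].getInputCols()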

In [9]:
# BigQuery coordinates of the table that holds the rows to be scored.
project_id = "lively-encoder-448916-d5"
dataset_id = "nyc_subway"
table_id = "hourly_future_pred_input"

# Load that table through the "bigquery" data source. The same project id is
# passed as the connector's parentProject option.
input_table = f"{project_id}.{dataset_id}.{table_id}"
bigquery_reader = (
    spark.read
    .format("bigquery")
    .option("table", input_table)
    .option("parentProject", project_id)
)
input_hourly_df = bigquery_reader.load()

# Peek at a single row to confirm the read worked.
input_hourly_df.show(1)

25/04/19 20:11:18 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                

+------------------+---------------------+-------------+--------------------+-----------+---------------+---------------+-----------+------------------+-------------------+-------------+------------------+-------------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+---------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+----------------+------------------+-----------------+-----------------+-------------------+---------------------+---------------------+---------------------+---------------------+-----------------------+-----------------------+------------------+------------------+------------------+--------------------+----------------------+----------------------+----------------------+------------------

In [10]:
# Score the input rows with the loaded pipeline and keep only the four encoded
# categorical columns plus the "prediction" column.
# NOTE(review): every other input column is dropped by this select; confirm that
# nothing else (e.g. a date/hour column) is needed in the output table.
kept_columns = [
    'transit_mode_index',
    'station_complex_index',
    'borough_index',
    'payment_method_index',
    'prediction',
]
output_predictions = model.transform(input_hourly_df).select(*kept_columns)

In [12]:
# Preview the first five scored rows (still index-encoded at this point).
output_predictions.show(n=5)

[Stage 10:>                                                         (0 + 1) / 1]

+------------------+---------------------+-------------+--------------------+-----------------+
|transit_mode_index|station_complex_index|borough_index|payment_method_index|       prediction|
+------------------+---------------------+-------------+--------------------+-----------------+
|                 2|                  316|            1|                   2|  5.5547194480896|
|                 2|                  390|            1|                   2|9.157230377197266|
|                 2|                  211|            1|                   1|8.512842178344727|
|                 2|                  300|            1|                   1|12.43112850189209|
|                 2|                   82|            1|                   2|26.37932586669922|
+------------------+---------------------+-------------+--------------------+-----------------+
only showing top 5 rows



                                                                                

In [11]:
from itertools import chain

# Spark's round is imported under an alias so that the Python builtin round()
# is not shadowed for the rest of the kernel session.
from pyspark.sql.functions import broadcast, col, create_map, lit, round as spark_round

# Decoding tables for the index-encoded categorical columns (index -> original value).
# NOTE(review): presumably these mirror the encoding used when the model was
# trained; confirm against the training pipeline.
payment_method_mapping = {1: 'metrocard', 2: 'omny'}
transit_mode_mapping = {1: 'staten_island_railway', 2: 'subway', 3: 'tram'}
borough_mapping = {1: 'Bronx', 2: 'Brooklyn', 3: 'Manhattan', 4: 'Queens', 5: 'Staten Island'}


def _literal_map(mapping):
    """Return a Spark map column holding the literal key/value pairs of ``mapping``.

    Args:
        mapping: small Python dict (here: encoded index -> original label).

    Returns:
        A Column built with ``create_map`` from alternating key and value literals.
    """
    return create_map(*[lit(item) for item in chain.from_iterable(mapping.items())])


# Lookup expressions for the three small, hard-coded mappings.
payment_method_expr = _literal_map(payment_method_mapping)
transit_mode_expr = _literal_map(transit_mode_mapping)
borough_expr = _literal_map(borough_mapping)

# Station names come from a CSV file. The CSV is read without schema inference,
# so the index column is cast to int explicitly.
station_df = (
    spark.read.csv("gs://bucket_jars/station_complex_csv/station_complex.csv", header=True)
    .withColumn("station_complex_index", col("station_complex_index").cast("int"))
    .select("station_complex_index", "station_complex")
)
# station_df.printSchema()

# The station table is joined rather than collected to the driver and inlined as
# a large literal map. Rows whose index is null after the cast are dropped (a
# null key cannot be matched and would break a literal map), and each index is
# kept once so the join cannot multiply prediction rows. If the CSV repeats an
# index, which of its names survives is arbitrary.
station_lookup = (
    station_df
    .where(col("station_complex_index").isNotNull())
    .dropDuplicates(["station_complex_index"])
    .withColumnRenamed("station_complex_index", "_station_key")
)

# Decode every index column, round the prediction to a whole number and keep
# only the human-readable columns. The left join keeps predictions whose station
# index has no match; their station_complex is null.
predictions_df = (
    output_predictions
    .join(
        broadcast(station_lookup),
        col("station_complex_index").cast("int") == col("_station_key"),
        "left",
    )
    .select(
        transit_mode_expr[col("transit_mode_index").cast("int")].alias("transit_mode"),
        col("station_complex"),
        borough_expr[col("borough_index").cast("int")].alias("borough"),
        payment_method_expr[col("payment_method_index").cast("int")].alias("payment_method"),
        spark_round("prediction", 0).cast("integer").alias("prediction"),
    )
)
predictions_df.show(5, truncate=False)

[Stage 9:>                                                          (0 + 1) / 1]

+------------+--------------------+-------+--------------+----------+
|transit_mode|station_complex     |borough|payment_method|prediction|
+------------+--------------------+-------+--------------+----------+
|subway      |Kingsbridge Rd (B,D)|Bronx  |omny          |6         |
|subway      |Simpson St (2,5)    |Bronx  |omny          |9         |
|subway      |Burnside Av (4)     |Bronx  |metrocard     |9         |
|subway      |Hunts Point Av (6)  |Bronx  |metrocard     |12        |
|subway      |3 Av-149 St (2,5)   |Bronx  |omny          |26        |
+------------+--------------------+-------+--------------+----------+
only showing top 5 rows



                                                                                

In [13]:
# Write the decoded predictions to BigQuery.
# The destination reuses project_id / dataset_id from the configuration cell, so
# the input and output tables always live in the same project and dataset.
output_table_id = "hourly_future_pred_output"
temp_gcs_bucket = "temp_nyc_bucket_for_bq"  # passed as the connector's temporaryGcsBucket option

(
    predictions_df.write
    .format("bigquery")
    .option("table", f"{project_id}.{dataset_id}.{output_table_id}")
    .option("temporaryGcsBucket", temp_gcs_bucket)
    .mode("overwrite")  # each run replaces the existing contents of the table
    .save()
)

                                                                                