In [1]:
from pyspark.ml import PipelineModel
from pyspark.sql.functions import col, lit, to_timestamp
from pyspark.sql.types import IntegerType

# Load the trained Spark ML pipeline from GCS; the forecast loop scores each
# hour's input with model.transform().
model_path = "gs://model_stored_for_pred_forecast/hourly_prediction/"
model = PipelineModel.load(path=model_path)

                                                                                

In [2]:
from google.cloud import bigquery

# Initialize BigQuery client.
# One shared client runs every CREATE/SELECT statement in this notebook; each
# call site blocks on `.result()` until its job finishes.
# NOTE(review): no project/credentials are passed, so the environment defaults
# are used — presumably the cluster's service account; confirm.
client = bigquery.Client()

In [3]:
# Seed history table `dec_till_30`: every hour_model row from
# 2024-12-28 00:00 through 2024-12-30 23:00 UTC.
# The forecast loop reads this table to build lags and moving averages, and
# query_union appends each forecast hour back onto it.
# `prediction` is seeded with the observed ridership, so the historical hours
# feed actual values into the lag / moving-average features.
# This column order is the layout query_buffer must reproduce, because
# query_union appends with a positional `SELECT * ... UNION ALL`.
query_nov = """
            CREATE OR REPLACE TABLE `lively-encoder-448916-d5.nyc_subway.dec_till_30` AS
            SELECT
            transit_timestamp,
            transit_mode_index,
            station_complex_index,
            borough_index,
            payment_method_index,
            ridership AS actual_ridership,
            ridership AS prediction,
            hour_of_day,
            hour_of_day_sin,
            hour_of_day_cos,
            day_of_week,
            day_of_week_sin,
            day_of_week_cos,
            week_of_month,
            week_of_month_sin,
            week_of_month_cos,
            ridership_lag_1,
            ridership_lag_2,
            ridership_lag_3,
            ridership_lag_4,
            ridership_lag_5,
            ridership_lag_6,
            ridership_lag_7,
            ridership_lag_8,
            ridership_lag_9,
            ridership_lag_10,
            ridership_lag_11,
            ridership_lag_12,
            ridership_lag_13,
            ridership_lag_14,
            ridership_lag_15,
            ridership_lag_16,
            ridership_lag_17,
            ridership_lag_18,
            ridership_lag_19,
            ridership_lag_20,
            ridership_lag_21,
            ridership_lag_22,
            ridership_lag_23,
            ridership_lag_24,
            ridership_7d_mv,
            hour_of_day_7d_mv,
            day_of_week_7d_mv,
            week_of_month_7d_mv,
            hour_of_day_sin_7d_mv,
            hour_of_day_cos_7d_mv,
            day_of_week_sin_7d_mv,
            day_of_week_cos_7d_mv,
            week_of_month_sin_7d_mv,
            week_of_month_cos_7d_mv,
            ridership_30d_mv,
            hour_of_day_30d_mv,
            day_of_week_30d_mv,
            week_of_month_30d_mv,
            hour_of_day_sin_30d_mv,
            hour_of_day_cos_30d_mv,
            day_of_week_sin_30d_mv,
            day_of_week_cos_30d_mv,
            week_of_month_sin_30d_mv,
            week_of_month_cos_30d_mv
            FROM `lively-encoder-448916-d5.nyc_subway.hour_model`
            WHERE transit_timestamp >= '2024-12-28 00:00:00 UTC' AND transit_timestamp <= '2024-12-30 23:00:00 UTC';
            """

In [4]:
# Build `dec_31`: the observed hour_model rows for 2024-12-31 00:00-23:00 UTC.
# Later, query_buffer uses only `actual_ridership` from this table. It joins on
# transit_timestamp plus the four *_index columns to attach the actual value to
# each forecast row.
query_dec = """
            CREATE OR REPLACE TABLE `lively-encoder-448916-d5.nyc_subway.dec_31` AS
            SELECT
            transit_timestamp,
            transit_mode_index,
            station_complex_index,
            borough_index,
            payment_method_index,
            ridership AS actual_ridership,
            hour_of_day,
            hour_of_day_sin,
            hour_of_day_cos,
            day_of_week,
            day_of_week_sin,
            day_of_week_cos,
            week_of_month,
            week_of_month_sin,
            week_of_month_cos,
            ridership_lag_1,
            ridership_lag_2,
            ridership_lag_3,
            ridership_lag_4,
            ridership_lag_5,
            ridership_lag_6,
            ridership_lag_7,
            ridership_lag_8,
            ridership_lag_9,
            ridership_lag_10,
            ridership_lag_11,
            ridership_lag_12,
            ridership_lag_13,
            ridership_lag_14,
            ridership_lag_15,
            ridership_lag_16,
            ridership_lag_17,
            ridership_lag_18,
            ridership_lag_19,
            ridership_lag_20,
            ridership_lag_21,
            ridership_lag_22,
            ridership_lag_23,
            ridership_lag_24,
            ridership_7d_mv,
            hour_of_day_7d_mv,
            day_of_week_7d_mv,
            week_of_month_7d_mv,
            hour_of_day_sin_7d_mv,
            hour_of_day_cos_7d_mv,
            day_of_week_sin_7d_mv,
            day_of_week_cos_7d_mv,
            week_of_month_sin_7d_mv,
            week_of_month_cos_7d_mv,
            ridership_30d_mv,
            hour_of_day_30d_mv,
            day_of_week_30d_mv,
            week_of_month_30d_mv,
            hour_of_day_sin_30d_mv,
            hour_of_day_cos_30d_mv,
            day_of_week_sin_30d_mv,
            day_of_week_cos_30d_mv,
            week_of_month_sin_30d_mv,
            week_of_month_cos_30d_mv
            FROM `lively-encoder-448916-d5.nyc_subway.hour_model`
            WHERE transit_timestamp >= '2024-12-31 00:00:00 UTC' AND transit_timestamp <= '2024-12-31 23:00:00 UTC';
            """

In [5]:
# Calculate lags.
# Build `dec_31_lags`: one row per series, stamped with the hour after the
# latest hour in dec_till_30 (the hour being forecast).
# The lags shift by one position:
#   - the latest hour's `prediction` becomes ridership_lag_1;
#   - its ridership_lag_1 becomes ridership_lag_2, and so on;
#   - its old ridership_lag_24 is dropped.
# Only series that have a row at dec_till_30's latest hour are carried forward.
query_lags = """
                CREATE OR REPLACE TABLE `lively-encoder-448916-d5.nyc_subway.dec_31_lags`
                AS
                    SELECT
                        (SELECT TIMESTAMP_ADD(MAX(transit_timestamp), INTERVAL 1 HOUR)
                            FROM `lively-encoder-448916-d5.nyc_subway.dec_till_30`) AS transit_timestamp,
                        transit_mode_index, station_complex_index, borough_index, payment_method_index,
                        prediction as ridership_lag_1,
                        ridership_lag_1 as ridership_lag_2,
                        ridership_lag_2 as ridership_lag_3,
                        ridership_lag_3 as ridership_lag_4,
                        ridership_lag_4 as ridership_lag_5,
                        ridership_lag_5 as ridership_lag_6,
                        ridership_lag_6 as ridership_lag_7,
                        ridership_lag_7 as ridership_lag_8,
                        ridership_lag_8 as ridership_lag_9,
                        ridership_lag_9 as ridership_lag_10,
                        ridership_lag_10 as ridership_lag_11,
                        ridership_lag_11 as ridership_lag_12,
                        ridership_lag_12 as ridership_lag_13,
                        ridership_lag_13 as ridership_lag_14,
                        ridership_lag_14 as ridership_lag_15,
                        ridership_lag_15 as ridership_lag_16,
                        ridership_lag_16 as ridership_lag_17,
                        ridership_lag_17 as ridership_lag_18,
                        ridership_lag_18 as ridership_lag_19,
                        ridership_lag_19 as ridership_lag_20,
                        ridership_lag_20 as ridership_lag_21,
                        ridership_lag_21 as ridership_lag_22,
                        ridership_lag_22 as ridership_lag_23,
                        ridership_lag_23 as ridership_lag_24
                    FROM `lively-encoder-448916-d5.nyc_subway.dec_till_30`
                    WHERE transit_timestamp = (SELECT MAX(transit_timestamp) FROM `lively-encoder-448916-d5.nyc_subway.dec_till_30`)
                """

In [6]:
# Calculate moving averages.
# Build `dec_31_ma`: rolling averages of `prediction` and the calendar features
# for each series (partitioned by the four *_index columns, ordered by time).
# The result is kept only on dec_till_30's latest hour.
# `next_transit_timestamp` is that hour + 1, i.e. the hour being forecast.
# NOTE(review): both windows are ROWS-based: 7 and 30 rows of hourly data, not
# 7 and 30 days, despite the "-- 7-day"/"-- 30-day" labels and the
# *_7d_mv/*_30d_mv names. Confirm this matches how hour_model computed these
# features for training.
# The WHERE clause runs before the window functions and keeps only the last 30
# hours. That fills the 30-row window only when the series has a row for every
# one of those hours.
query_ma = """
            CREATE OR REPLACE TABLE `lively-encoder-448916-d5.nyc_subway.dec_31_ma`
            AS (
            WITH ridership_moving_avg AS (
                SELECT
                    transit_timestamp,
                    transit_mode_index,
                    station_complex_index,
                    borough_index,
                    payment_method_index,
                    -- 7-day moving average
                    AVG(prediction) OVER (
                        PARTITION BY station_complex_index, transit_mode_index, borough_index, payment_method_index
                        ORDER BY transit_timestamp
                        ROWS BETWEEN 6 PRECEDING AND 0 PRECEDING
                    ) AS ridership_7d_mv,
                    AVG(hour_of_day) OVER (
                        PARTITION BY station_complex_index, transit_mode_index, borough_index, payment_method_index
                        ORDER BY transit_timestamp
                        ROWS BETWEEN 6 PRECEDING AND 0 PRECEDING
                    ) AS hour_of_day_7d_mv,
                    AVG(day_of_week) OVER (
                        PARTITION BY station_complex_index, transit_mode_index, borough_index, payment_method_index
                        ORDER BY transit_timestamp
                        ROWS BETWEEN 6 PRECEDING AND 0 PRECEDING
                    ) AS day_of_week_7d_mv,
                    AVG(week_of_month) OVER (
                        PARTITION BY station_complex_index, transit_mode_index, borough_index, payment_method_index
                        ORDER BY transit_timestamp
                        ROWS BETWEEN 6 PRECEDING AND 0 PRECEDING
                    ) AS week_of_month_7d_mv,
                    AVG(hour_of_day_sin) OVER (
                        PARTITION BY station_complex_index, transit_mode_index, borough_index, payment_method_index
                        ORDER BY transit_timestamp
                        ROWS BETWEEN 6 PRECEDING AND 0 PRECEDING
                    ) AS hour_of_day_sin_7d_mv,
                        AVG(hour_of_day_cos) OVER (
                        PARTITION BY station_complex_index, transit_mode_index, borough_index, payment_method_index
                        ORDER BY transit_timestamp
                        ROWS BETWEEN 6 PRECEDING AND 0 PRECEDING
                    ) AS hour_of_day_cos_7d_mv,
                    AVG(day_of_week_sin) OVER (
                        PARTITION BY station_complex_index, transit_mode_index, borough_index, payment_method_index
                        ORDER BY transit_timestamp
                        ROWS BETWEEN 6 PRECEDING AND 0 PRECEDING
                    ) AS day_of_week_sin_7d_mv,
                    AVG(day_of_week_cos) OVER (
                        PARTITION BY station_complex_index, transit_mode_index, borough_index, payment_method_index
                        ORDER BY transit_timestamp
                        ROWS BETWEEN 6 PRECEDING AND 0 PRECEDING
                    ) AS day_of_week_cos_7d_mv,
                    AVG(week_of_month_sin) OVER (
                        PARTITION BY station_complex_index, transit_mode_index, borough_index, payment_method_index
                        ORDER BY transit_timestamp
                        ROWS BETWEEN 6 PRECEDING AND 0 PRECEDING
                    ) AS week_of_month_sin_7d_mv,
                    AVG(week_of_month_cos) OVER (
                        PARTITION BY station_complex_index, transit_mode_index, borough_index, payment_method_index
                        ORDER BY transit_timestamp
                        ROWS BETWEEN 6 PRECEDING AND 0 PRECEDING
                    ) AS week_of_month_cos_7d_mv,
                    -- 30-day moving average
                    AVG(prediction) OVER (
                        PARTITION BY station_complex_index, transit_mode_index, borough_index, payment_method_index
                        ORDER BY transit_timestamp
                        ROWS BETWEEN 29 PRECEDING AND 0 PRECEDING
                    ) AS ridership_30d_mv,
                    AVG(hour_of_day) OVER (
                        PARTITION BY station_complex_index, transit_mode_index, borough_index, payment_method_index
                        ORDER BY transit_timestamp
                        ROWS BETWEEN 29 PRECEDING AND 0 PRECEDING
                    ) AS hour_of_day_30d_mv,
                    AVG(day_of_week) OVER (
                        PARTITION BY station_complex_index, transit_mode_index, borough_index, payment_method_index
                        ORDER BY transit_timestamp
                        ROWS BETWEEN 29 PRECEDING AND 0 PRECEDING
                    ) AS day_of_week_30d_mv,
                    AVG(week_of_month) OVER (
                        PARTITION BY station_complex_index, transit_mode_index, borough_index, payment_method_index
                        ORDER BY transit_timestamp
                        ROWS BETWEEN 29 PRECEDING AND 0 PRECEDING
                    ) AS week_of_month_30d_mv,
                    AVG(hour_of_day_sin) OVER (
                        PARTITION BY station_complex_index, transit_mode_index, borough_index, payment_method_index
                        ORDER BY transit_timestamp
                        ROWS BETWEEN 29 PRECEDING AND 0 PRECEDING
                    ) AS hour_of_day_sin_30d_mv,
                    AVG(hour_of_day_cos) OVER (
                        PARTITION BY station_complex_index, transit_mode_index, borough_index, payment_method_index
                        ORDER BY transit_timestamp
                        ROWS BETWEEN 29 PRECEDING AND 0 PRECEDING
                    ) AS hour_of_day_cos_30d_mv,
                    AVG(day_of_week_sin) OVER (
                        PARTITION BY station_complex_index, transit_mode_index, borough_index, payment_method_index
                        ORDER BY transit_timestamp
                        ROWS BETWEEN 29 PRECEDING AND 0 PRECEDING
                    ) AS day_of_week_sin_30d_mv,
                    AVG(day_of_week_cos) OVER (
                        PARTITION BY station_complex_index, transit_mode_index, borough_index, payment_method_index
                        ORDER BY transit_timestamp
                        ROWS BETWEEN 29 PRECEDING AND 0 PRECEDING
                    ) AS day_of_week_cos_30d_mv,
                    AVG(week_of_month_sin) OVER (
                        PARTITION BY station_complex_index, transit_mode_index, borough_index, payment_method_index
                        ORDER BY transit_timestamp
                        ROWS BETWEEN 29 PRECEDING AND 0 PRECEDING
                    ) AS week_of_month_sin_30d_mv,
                    AVG(week_of_month_cos) OVER (
                        PARTITION BY station_complex_index, transit_mode_index, borough_index, payment_method_index
                        ORDER BY transit_timestamp
                        ROWS BETWEEN 29 PRECEDING AND 0 PRECEDING
                    ) AS week_of_month_cos_30d_mv
                FROM `lively-encoder-448916-d5.nyc_subway.dec_till_30`
                WHERE transit_timestamp >=  (SELECT TIMESTAMP_SUB(MAX(transit_timestamp), INTERVAL 29 HOUR) 
                                            FROM `lively-encoder-448916-d5.nyc_subway.dec_till_30`) 
                AND transit_timestamp <= (SELECT MAX(transit_timestamp) 
                                            FROM `lively-encoder-448916-d5.nyc_subway.dec_till_30`)
            )
            SELECT
                    (SELECT TIMESTAMP_ADD(MAX(transit_timestamp), INTERVAL 1 HOUR)
                    FROM `lively-encoder-448916-d5.nyc_subway.dec_till_30`) AS next_transit_timestamp,
                *
            FROM
                ridership_moving_avg
            WHERE transit_timestamp = (SELECT MAX(transit_timestamp) FROM `lively-encoder-448916-d5.nyc_subway.dec_till_30`)
            ORDER BY transit_timestamp DESC
            );
            """

In [7]:
# Combine lags and ma tables to create input df for model.
# Build `dec_31_input`: the lags joined with the moving averages per series
# (on the four *_index columns).
# Calendar features are recomputed from the forecast hour's timestamp:
#   - hour of day, with sin/cos over a 24-hour period;
#   - day of week, with sin/cos over a 7-day period;
#   - week of month (CEIL(day / 7)), with sin/cos over a 5-week period.
# NOTE(review): EXTRACT on a TIMESTAMP is evaluated in UTC here, and BigQuery's
# DAYOFWEEK is 1 = Sunday .. 7 = Saturday. Confirm hour_model derived
# hour_of_day / day_of_week / week_of_month the same way for training.
query_input = """
                CREATE OR REPLACE TABLE `lively-encoder-448916-d5.nyc_subway.dec_31_input`
                AS (
                SELECT
                    l.transit_timestamp,
                    l.transit_mode_index,
                    l.station_complex_index,
                    l.borough_index,
                    l.payment_method_index,
                    EXTRACT(HOUR FROM l.transit_timestamp) AS hour_of_day,
                    SIN(2 * ACOS(-1) * EXTRACT(HOUR FROM l.transit_timestamp) / 24) AS hour_of_day_sin,
                    COS(2 * ACOS(-1) * EXTRACT(HOUR FROM l.transit_timestamp) / 24) AS hour_of_day_cos,
                    EXTRACT(DAYOFWEEK FROM l.transit_timestamp) AS day_of_week,
                    SIN(2 * ACOS(-1) * EXTRACT(DAYOFWEEK FROM l.transit_timestamp) / 7) AS day_of_week_sin,
                    COS(2 * ACOS(-1) * EXTRACT(DAYOFWEEK FROM  l.transit_timestamp) / 7) AS day_of_week_cos,
                    CEIL(EXTRACT(DAY FROM  l.transit_timestamp) / 7) AS week_of_month,
                    SIN(2 * ACOS(-1) * CEIL(EXTRACT(DAY FROM  l.transit_timestamp) / 7) / 5) AS week_of_month_sin,
                    COS(2 * ACOS(-1) * CEIL(EXTRACT(DAY FROM  l.transit_timestamp) / 7) / 5) AS week_of_month_cos,
                    l.ridership_lag_1, l.ridership_lag_2, l.ridership_lag_3, l.ridership_lag_4, l.ridership_lag_5,
                    l.ridership_lag_6, l.ridership_lag_7, l.ridership_lag_8, l.ridership_lag_9, l.ridership_lag_10,
                    l.ridership_lag_11, l.ridership_lag_12, l.ridership_lag_13, l.ridership_lag_14, l.ridership_lag_15,
                    l.ridership_lag_16, l.ridership_lag_17, l.ridership_lag_18, l.ridership_lag_19, l.ridership_lag_20,
                    l.ridership_lag_21, l.ridership_lag_22, l.ridership_lag_23, l.ridership_lag_24,
                    mv.ridership_7d_mv,
                    mv.hour_of_day_7d_mv, mv.day_of_week_7d_mv, mv.week_of_month_7d_mv,
                    mv.hour_of_day_sin_7d_mv, mv.hour_of_day_cos_7d_mv,
                    mv.day_of_week_sin_7d_mv, mv.day_of_week_cos_7d_mv,
                    mv.week_of_month_sin_7d_mv, mv.week_of_month_cos_7d_mv,
                    mv.ridership_30d_mv,
                    mv.hour_of_day_30d_mv, mv.day_of_week_30d_mv, mv.week_of_month_30d_mv,
                    mv.hour_of_day_sin_30d_mv, mv.hour_of_day_cos_30d_mv,
                    mv.day_of_week_sin_30d_mv, mv.day_of_week_cos_30d_mv,
                    mv.week_of_month_sin_30d_mv, mv.week_of_month_cos_30d_mv
                    FROM `lively-encoder-448916-d5.nyc_subway.dec_31_lags` l
                    INNER JOIN `lively-encoder-448916-d5.nyc_subway.dec_31_ma` mv
                    ON l.transit_mode_index = mv.transit_mode_index
                    AND l.station_complex_index = mv.station_complex_index
                    AND l.borough_index = mv.borough_index
                    AND l.payment_method_index = mv.payment_method_index
                    );
                """

In [8]:
# Assemble one scored hour in dec_till_30's column layout:
#   - the model inputs from dec_31_input;
#   - the Spark prediction from dec_31_output;
#   - the observed value from dec_31.
# The column order must match dec_till_30 exactly, because query_union appends
# this table with a positional `SELECT * ... UNION ALL`.
#
# dec_31 (actuals) is LEFT JOINed on purpose. With an INNER JOIN, a series with
# no actual row for this hour would be missing from dec_till_30's newest hour.
# query_lags / query_ma only read that newest hour, so the series would silently
# stop being forecast for every remaining step. Such rows now keep their
# prediction, with actual_ridership = NULL.
query_buffer = """
                CREATE OR REPLACE TABLE `lively-encoder-448916-d5.nyc_subway.dec_31_buffer` AS
                SELECT
                  -- Key columns from dec_31_input
                  dfi.transit_timestamp,
                  dfi.transit_mode_index,
                  dfi.station_complex_index,
                  dfi.borough_index,
                  dfi.payment_method_index,
                  -- Observed value from dec_31 (NULL when no actual row exists)
                  dd.actual_ridership,
                  -- Model prediction from dec_31_output
                  dfo.prediction,
                  -- Feature columns from dec_31_input
                  dfi.hour_of_day,
                  dfi.hour_of_day_sin,
                  dfi.hour_of_day_cos,
                  dfi.day_of_week,
                  dfi.day_of_week_sin,
                  dfi.day_of_week_cos,
                  dfi.week_of_month,
                  dfi.week_of_month_sin,
                  dfi.week_of_month_cos,
                  dfi.ridership_lag_1,
                  dfi.ridership_lag_2,
                  dfi.ridership_lag_3,
                  dfi.ridership_lag_4,
                  dfi.ridership_lag_5,
                  dfi.ridership_lag_6,
                  dfi.ridership_lag_7,
                  dfi.ridership_lag_8,
                  dfi.ridership_lag_9,
                  dfi.ridership_lag_10,
                  dfi.ridership_lag_11,
                  dfi.ridership_lag_12,
                  dfi.ridership_lag_13,
                  dfi.ridership_lag_14,
                  dfi.ridership_lag_15,
                  dfi.ridership_lag_16,
                  dfi.ridership_lag_17,
                  dfi.ridership_lag_18,
                  dfi.ridership_lag_19,
                  dfi.ridership_lag_20,
                  dfi.ridership_lag_21,
                  dfi.ridership_lag_22,
                  dfi.ridership_lag_23,
                  dfi.ridership_lag_24,
                  dfi.ridership_7d_mv,
                  dfi.hour_of_day_7d_mv,
                  dfi.day_of_week_7d_mv,
                  dfi.week_of_month_7d_mv,
                  dfi.hour_of_day_sin_7d_mv,
                  dfi.hour_of_day_cos_7d_mv,
                  dfi.day_of_week_sin_7d_mv,
                  dfi.day_of_week_cos_7d_mv,
                  dfi.week_of_month_sin_7d_mv,
                  dfi.week_of_month_cos_7d_mv,
                  dfi.ridership_30d_mv,
                  dfi.hour_of_day_30d_mv,
                  dfi.day_of_week_30d_mv,
                  dfi.week_of_month_30d_mv,
                  dfi.hour_of_day_sin_30d_mv,
                  dfi.hour_of_day_cos_30d_mv,
                  dfi.day_of_week_sin_30d_mv,
                  dfi.day_of_week_cos_30d_mv,
                  dfi.week_of_month_sin_30d_mv,
                  dfi.week_of_month_cos_30d_mv

                FROM `lively-encoder-448916-d5.nyc_subway.dec_31_input` AS dfi
                INNER JOIN `lively-encoder-448916-d5.nyc_subway.dec_31_output` AS dfo
                  ON dfi.transit_timestamp = dfo.transit_timestamp
                  AND dfi.transit_mode_index = dfo.transit_mode_index
                  AND dfi.station_complex_index = dfo.station_complex_index
                  AND dfi.borough_index = dfo.borough_index
                  AND dfi.payment_method_index = dfo.payment_method_index

                LEFT JOIN `lively-encoder-448916-d5.nyc_subway.dec_31` AS dd
                  ON dfi.transit_timestamp = dd.transit_timestamp
                  AND dfi.transit_mode_index = dd.transit_mode_index
                  AND dfi.station_complex_index = dd.station_complex_index
                  AND dfi.borough_index = dd.borough_index
                  AND dfi.payment_method_index = dd.payment_method_index;
                """

In [9]:
# Append the scored hour (dec_31_buffer) to the history table dec_till_30, so
# the next loop iteration builds its lags and moving averages from it.
# `SELECT * ... UNION ALL` matches columns by position, not by name, so
# dec_31_buffer must list its columns in dec_till_30's order.
# Not idempotent: every execution appends the buffer again.
query_union = """
                CREATE OR REPLACE TABLE `lively-encoder-448916-d5.nyc_subway.dec_till_30` AS
                SELECT * FROM `lively-encoder-448916-d5.nyc_subway.dec_till_30`
                UNION ALL
                SELECT * FROM `lively-encoder-448916-d5.nyc_subway.dec_31_buffer`
                """

In [10]:
# Rebuild the seed history table dec_till_30 (2024-12-28 00:00 .. 2024-12-30
# 23:00 UTC) and the Dec 31 actuals table dec_31.
#
# This must run before the forecast loop on every fresh run. The loop's
# query_union appends each forecast hour to dec_till_30 in place. Skipping the
# reset makes the loop continue from whatever an earlier run left behind, and
# dec_31_avp would then mix old and new forecast rows.
# Set RESET_TABLES = False only to deliberately extend an existing run.
RESET_TABLES = True

if RESET_TABLES:
    client.query(query_nov).result()
    client.query(query_dec).result()

In [None]:
import time

from pyspark.sql.functions import round as spark_round

# Recursive hourly forecast. Each step:
#   1. scores the hour after the latest hour in dec_till_30;
#   2. appends that hour (prediction + actual) back onto dec_till_30, so the
#      next step's lags and moving averages are built from it.
# NOTE(review): dec_31 holds actuals for 00:00-23:00 UTC, but only 3 steps are
# run here. Set FORECAST_HOURS = 24 to forecast the whole day.
FORECAST_HOURS = 3

for step in range(FORECAST_HOURS):

    # Add lags, moving averages, and combine them into dec_31_input
    client.query(query_lags).result()
    client.query(query_ma).result()
    client.query(query_input).result()
    # Fixed pause kept from the original; each .result() above already blocks
    # until its job has finished.
    time.sleep(4)

    # Read input data from BigQuery table into a Spark DataFrame
    input_hourly_df = spark.read \
        .format("bigquery") \
        .option("table", "lively-encoder-448916-d5.nyc_subway.dec_31_input") \
        .option("parentProject", "lively-encoder-448916-d5") \
        .load()

    # Every input row shares one timestamp (query_lags sets it to the latest
    # dec_till_30 hour + 1), so read it once and re-attach it after scoring.
    first_row = input_hourly_df.select("transit_timestamp").first()
    if first_row is None:
        # Fail with a clear message instead of `TypeError: 'NoneType' ...`.
        raise RuntimeError(
            f"dec_31_input is empty at forecast step {step + 1}; "
            "dec_till_30 may be empty or not seeded (see RESET_TABLES / query_nov)"
        )
    df_transit_timestamp = first_row["transit_timestamp"]
    print(df_transit_timestamp)

    # Drop the timestamp column before inputting to the model
    input_hourly_df = input_hourly_df.drop("transit_timestamp")

    # Get predictions
    output_predictions = model.transform(input_hourly_df)
    # NOTE(review): the timestamp makes a Python round trip here; confirm the
    # driver's local timezone matches spark.sql.session.timeZone, otherwise the
    # equality join in query_buffer could miss.
    output_predictions = output_predictions.withColumn("transit_timestamp", to_timestamp(lit(df_transit_timestamp)))
    # Round to the nearest integer before casting. A bare cast truncates toward
    # zero, biasing every prediction downward. That bias compounds, because each
    # prediction becomes the next step's ridership_lag_1.
    output_predictions = output_predictions.withColumn("prediction", spark_round(col("prediction")).cast(IntegerType()))
    output_predictions = output_predictions.select('transit_timestamp', 'transit_mode_index', 'station_complex_index', 'borough_index', 'payment_method_index', 'prediction')

    # Write predictions (dec_31_output is replaced on every step)
    output_predictions.write.format("bigquery") \
        .option("table", "lively-encoder-448916-d5.nyc_subway.dec_31_output") \
        .option("temporaryGcsBucket", "temp_dec_forecast_bucket") \
        .mode("overwrite") \
        .save()

    time.sleep(4)
    # Append this hour to dec_till_30: join inputs, prediction and actuals into
    # dec_31_buffer, then union it onto the history table.
    client.query(query_buffer).result()
    client.query(query_union).result()

    time.sleep(1)
    
    

In [None]:
# Collect every forecast hour (anything after the last seeded hour,
# 2024-12-30 23:00 UTC) with its actual value into dec_31_avp, for the
# actual-vs-predicted comparison below.
# NOTE(review): ORDER BY inside CREATE TABLE AS SELECT does not give the stored
# table a guaranteed row order; sort when reading if order matters.
query_dec_avp = """
                CREATE OR REPLACE TABLE `lively-encoder-448916-d5.nyc_subway.dec_31_avp` AS
                SELECT
                    transit_timestamp, transit_mode_index,
                    station_complex_index, borough_index,
                    payment_method_index,
                    actual_ridership, prediction
                FROM `lively-encoder-448916-d5.nyc_subway.dec_till_30`
                WHERE transit_timestamp > '2024-12-30 23:00:00 UTC'
                ORDER BY transit_timestamp DESC;
                """
client.query(query_dec_avp).result()

In [None]:
# Load the actual-vs-predicted table into Spark. The BigQuery data is only
# pulled when an action (show/write) runs on the DataFrame.
dec_31_avp = (
    spark.read.format("bigquery")
    .options(
        table="lively-encoder-448916-d5.nyc_subway.dec_31_avp",
        parentProject="lively-encoder-448916-d5",
    )
    .load()
)

In [None]:
from pyspark.sql.functions import col, lit, create_map, round, to_date

# Map encoded categorical columns to their original values.
# NOTE(review): these code -> label tables are hard-coded; confirm they match
# the encoding used when hour_model's *_index columns were built.
payment_method_mapping = {1: 'metrocard', 2: 'omny'}
transit_mode_mapping = {1: 'staten_island_railway', 2: 'subway', 3: 'tram'}
borough_mapping = {1: 'Bronx', 2: 'Brooklyn', 3: 'Manhattan', 4: 'Queens', 5: 'Staten Island'}

# Station index -> name lookup read from the station CSV.
# Rows whose index is missing (NULL after the int cast) are dropped.
# create_map rejects NULL keys, so one bad CSV row would otherwise make the
# station lookup below fail.
station_df = spark.read.csv("gs://bucket_jars/station_complex_csv/station_complex.csv", header=True)
station_df = station_df.withColumn("station_complex_index", col("station_complex_index").cast("int"))
station_df = station_df.select("station_complex_index", "station_complex").dropna(subset=["station_complex_index"])

station_map = station_df.rdd.collectAsMap()


def _literal_map(mapping):
    """Build a Spark map-literal Column holding every key/value pair of `mapping`.

    Args:
        mapping: dict of Python scalars; keys must not be None.

    Returns:
        A `create_map(...)` Column that can be indexed with another column,
        e.g. `expr[col("x")]`, to look values up row by row.
    """
    flat_pairs = [lit(part) for pair in mapping.items() for part in pair]
    return create_map(*flat_pairs)


# Create mapping expressions
payment_method_expr = _literal_map(payment_method_mapping)
transit_mode_expr = _literal_map(transit_mode_mapping)
borough_expr = _literal_map(borough_mapping)
station_complex_expr = _literal_map(station_map)

# Apply the mappings
dec_31_avp_mapped = dec_31_avp.withColumn(
                        "payment_method", 
                        payment_method_expr[col("payment_method_index").cast("int")]
                    ).withColumn(
                        "transit_mode", 
                        transit_mode_expr[col("transit_mode_index").cast("int")]
                    ).withColumn(
                        "borough", 
                        borough_expr[col("borough_index").cast("int")]
                    ).withColumn(
                        "station_complex", 
                        station_complex_expr[col("station_complex_index").cast("int")]
                    )

In [None]:
# dec_31_avp_mapped.show(5)

In [None]:
# Keep only the human-readable columns (dropping the *_index codes), in report
# order, then preview a few rows.
_avp_report_columns = [
    'transit_timestamp',
    'transit_mode',
    'station_complex',
    'borough',
    'payment_method',
    'actual_ridership',
    'prediction',
]
dec_31_avp_mapped = dec_31_avp_mapped.select(*_avp_report_columns)
dec_31_avp_mapped.show(n=3)

In [None]:
# Write forecast.
# NOTE(review): this overwrites dec_31_avp, the same table dec_31_avp was read
# from. Afterwards the table no longer has the *_index columns, so re-running
# the read/mapping cells without first re-running query_dec_avp would fail.
(
    dec_31_avp_mapped.write
    .format("bigquery")
    .options(
        table="lively-encoder-448916-d5.nyc_subway.dec_31_avp",
        temporaryGcsBucket="temp_dec_forecast_bucket",
    )
    .mode("overwrite")
    .save()
)