In [None]:
from pyspark.ml import PipelineModel
from pyspark.sql.functions import col, lit, to_timestamp
from pyspark.sql.types import IntegerType

# Trained Spark ML PipelineModel used by the forecast loop to score each hour's input rows.
model_path = "gs://model_stored_for_pred_forecast/hourly_prediction/"
model = PipelineModel.read().load(model_path)

In [None]:
from google.cloud import bigquery

# BigQuery client that executes every SQL step in this notebook. No project or credentials
# are passed explicitly, so the client library's defaults apply.
client = bigquery.Client(project=None)

In [None]:
# Step 1: Seed the rolling working table `recent_hours_data` with the last 30 hours of
# feature rows from `hour_model`: MAX(transit_timestamp) - 29 HOUR through MAX, inclusive.
# 30 hours covers the widest window used downstream (ROWS BETWEEN 29 PRECEDING in the
# *_30d_mv moving averages).
# `ridership` is renamed to `prediction`, and the column order must stay aligned with
# `hourly_forecast_buffer`: forecast rows are appended to this table with a positional
# `SELECT * ... UNION ALL SELECT *`.
# No upper-bound filter is needed: every row is <= MAX(transit_timestamp) by definition,
# and such a predicate would only cost an extra scan of `hour_model`.
query_recent_data = """
                    CREATE OR REPLACE TABLE `lively-encoder-448916-d5.nyc_subway.recent_hours_data` AS
                    SELECT
                    transit_timestamp,
                    transit_mode_index,
                    station_complex_index,
                    borough_index,
                    payment_method_index,
                    ridership AS prediction,
                    hour_of_day,
                    hour_of_day_sin,
                    hour_of_day_cos,
                    day_of_week,
                    day_of_week_sin,
                    day_of_week_cos,
                    week_of_month,
                    week_of_month_sin,
                    week_of_month_cos,
                    ridership_lag_1,
                    ridership_lag_2,
                    ridership_lag_3,
                    ridership_lag_4,
                    ridership_lag_5,
                    ridership_lag_6,
                    ridership_lag_7,
                    ridership_lag_8,
                    ridership_lag_9,
                    ridership_lag_10,
                    ridership_lag_11,
                    ridership_lag_12,
                    ridership_lag_13,
                    ridership_lag_14,
                    ridership_lag_15,
                    ridership_lag_16,
                    ridership_lag_17,
                    ridership_lag_18,
                    ridership_lag_19,
                    ridership_lag_20,
                    ridership_lag_21,
                    ridership_lag_22,
                    ridership_lag_23,
                    ridership_lag_24,
                    ridership_7d_mv,
                    hour_of_day_7d_mv,
                    day_of_week_7d_mv,
                    week_of_month_7d_mv,
                    hour_of_day_sin_7d_mv,
                    hour_of_day_cos_7d_mv,
                    day_of_week_sin_7d_mv,
                    day_of_week_cos_7d_mv,
                    week_of_month_sin_7d_mv,
                    week_of_month_cos_7d_mv,
                    ridership_30d_mv,
                    hour_of_day_30d_mv,
                    day_of_week_30d_mv,
                    week_of_month_30d_mv,
                    hour_of_day_sin_30d_mv,
                    hour_of_day_cos_30d_mv,
                    day_of_week_sin_30d_mv,
                    day_of_week_cos_30d_mv,
                    week_of_month_sin_30d_mv,
                    week_of_month_cos_30d_mv
                    FROM `lively-encoder-448916-d5.nyc_subway.hour_model`
                    WHERE transit_timestamp >= (SELECT TIMESTAMP_SUB(MAX(transit_timestamp), INTERVAL 29 HOUR)
                                                FROM `lively-encoder-448916-d5.nyc_subway.hour_model`)
                    """

In [None]:
# Step 2: Shift the lag features forward by one hour for the next forecast hour.
# Using only the rows at MAX(transit_timestamp) in `recent_hours_data`:
#   - the latest value `prediction` becomes ridership_lag_1. For seeded rows this column is
#     actual ridership (renamed in Step 1); for appended rows it is the model output.
#   - each ridership_lag_k becomes ridership_lag_{k+1}; the old ridership_lag_24 falls off.
# The hour being predicted is emitted as `next_transit_timestamp` (MAX + 1 HOUR). The output
# table has no `transit_timestamp` column.
query_lags = """
            CREATE OR REPLACE TABLE `lively-encoder-448916-d5.nyc_subway.hourly_forecast_lags`
            AS
            SELECT
                (SELECT TIMESTAMP_ADD(MAX(transit_timestamp), INTERVAL 1 HOUR)
                FROM `lively-encoder-448916-d5.nyc_subway.recent_hours_data`) AS next_transit_timestamp,
                transit_mode_index, station_complex_index, borough_index, payment_method_index,
                prediction as ridership_lag_1,
                ridership_lag_1 as ridership_lag_2,
                ridership_lag_2 as ridership_lag_3,
                ridership_lag_3 as ridership_lag_4,
                ridership_lag_4 as ridership_lag_5,
                ridership_lag_5 as ridership_lag_6,
                ridership_lag_6 as ridership_lag_7,
                ridership_lag_7 as ridership_lag_8,
                ridership_lag_8 as ridership_lag_9,
                ridership_lag_9 as ridership_lag_10,
                ridership_lag_10 as ridership_lag_11,
                ridership_lag_11 as ridership_lag_12,
                ridership_lag_12 as ridership_lag_13,
                ridership_lag_13 as ridership_lag_14,
                ridership_lag_14 as ridership_lag_15,
                ridership_lag_15 as ridership_lag_16,
                ridership_lag_16 as ridership_lag_17,
                ridership_lag_17 as ridership_lag_18,
                ridership_lag_18 as ridership_lag_19,
                ridership_lag_19 as ridership_lag_20,
                ridership_lag_20 as ridership_lag_21,
                ridership_lag_21 as ridership_lag_22,
                ridership_lag_22 as ridership_lag_23,
                ridership_lag_23 as ridership_lag_24
            FROM `lively-encoder-448916-d5.nyc_subway.recent_hours_data`
            WHERE transit_timestamp = (SELECT MAX(transit_timestamp) FROM `lively-encoder-448916-d5.nyc_subway.recent_hours_data`)
            """

In [None]:
# Step 3: Moving averages for the next forecast hour.
# For each (station_complex, transit_mode, borough, payment_method) series in the last
# 30 hours of `recent_hours_data`, this step averages the target (`prediction`) and the
# calendar features. It uses the trailing 7 rows (window w7) and 30 rows (window w30),
# each ending at the current row.
# Only the rows at MAX(transit_timestamp) are kept. They are stamped with
# next_transit_timestamp = MAX + 1 HOUR, the hour being predicted.
# The two window specs are declared once in a WINDOW clause instead of being repeated for
# all 20 aggregates. `CURRENT ROW` is the same frame end as `0 PRECEDING`.
# NOTE(review): despite the *_7d_mv / *_30d_mv names, these windows count rows, and the rows
# here are hourly. Confirm this matches how the training features were computed.
query_ma = """
            CREATE OR REPLACE TABLE `lively-encoder-448916-d5.nyc_subway.hourly_forecast_moving_avg`
            AS (
            WITH ridership_moving_avg AS (
                SELECT
                    transit_timestamp,
                    transit_mode_index,
                    station_complex_index,
                    borough_index,
                    payment_method_index,
                    -- 7-row moving averages
                    AVG(prediction) OVER w7 AS ridership_7d_mv,
                    AVG(hour_of_day) OVER w7 AS hour_of_day_7d_mv,
                    AVG(day_of_week) OVER w7 AS day_of_week_7d_mv,
                    AVG(week_of_month) OVER w7 AS week_of_month_7d_mv,
                    AVG(hour_of_day_sin) OVER w7 AS hour_of_day_sin_7d_mv,
                    AVG(hour_of_day_cos) OVER w7 AS hour_of_day_cos_7d_mv,
                    AVG(day_of_week_sin) OVER w7 AS day_of_week_sin_7d_mv,
                    AVG(day_of_week_cos) OVER w7 AS day_of_week_cos_7d_mv,
                    AVG(week_of_month_sin) OVER w7 AS week_of_month_sin_7d_mv,
                    AVG(week_of_month_cos) OVER w7 AS week_of_month_cos_7d_mv,
                    -- 30-row moving averages
                    AVG(prediction) OVER w30 AS ridership_30d_mv,
                    AVG(hour_of_day) OVER w30 AS hour_of_day_30d_mv,
                    AVG(day_of_week) OVER w30 AS day_of_week_30d_mv,
                    AVG(week_of_month) OVER w30 AS week_of_month_30d_mv,
                    AVG(hour_of_day_sin) OVER w30 AS hour_of_day_sin_30d_mv,
                    AVG(hour_of_day_cos) OVER w30 AS hour_of_day_cos_30d_mv,
                    AVG(day_of_week_sin) OVER w30 AS day_of_week_sin_30d_mv,
                    AVG(day_of_week_cos) OVER w30 AS day_of_week_cos_30d_mv,
                    AVG(week_of_month_sin) OVER w30 AS week_of_month_sin_30d_mv,
                    AVG(week_of_month_cos) OVER w30 AS week_of_month_cos_30d_mv
                FROM `lively-encoder-448916-d5.nyc_subway.recent_hours_data`
                WHERE transit_timestamp >= (SELECT TIMESTAMP_SUB(MAX(transit_timestamp), INTERVAL 29 HOUR)
                                            FROM `lively-encoder-448916-d5.nyc_subway.recent_hours_data`)
                WINDOW
                    w7 AS (
                        PARTITION BY station_complex_index, transit_mode_index, borough_index, payment_method_index
                        ORDER BY transit_timestamp
                        ROWS BETWEEN 6 PRECEDING AND CURRENT ROW
                    ),
                    w30 AS (
                        PARTITION BY station_complex_index, transit_mode_index, borough_index, payment_method_index
                        ORDER BY transit_timestamp
                        ROWS BETWEEN 29 PRECEDING AND CURRENT ROW
                    )
            )
            SELECT
                (SELECT TIMESTAMP_ADD(MAX(transit_timestamp), INTERVAL 1 HOUR)
                 FROM `lively-encoder-448916-d5.nyc_subway.recent_hours_data`) AS next_transit_timestamp,
                *
            FROM ridership_moving_avg
            WHERE transit_timestamp = (SELECT MAX(transit_timestamp) FROM `lively-encoder-448916-d5.nyc_subway.recent_hours_data`)
            );
            """

In [None]:
# Step 4: Combine the lags (Step 2) and moving averages (Step 3) into the model input table.
# Reads `hourly_forecast_lags` and `hourly_forecast_moving_avg`, the tables those steps rebuild
# on every loop iteration, so each iteration predicts from freshly shifted features.
# `hourly_forecast_lags` exposes the hour being predicted only as `next_transit_timestamp`.
# It is aliased to `transit_timestamp`, which the forecast loop reads by name, and the
# calendar features (hour / day-of-week / week-of-month plus their sin/cos encodings) are
# derived from that same timestamp.
query_input = """
                CREATE OR REPLACE TABLE `lively-encoder-448916-d5.nyc_subway.hourly_forecast_input`
                AS (
                SELECT
                    l.next_transit_timestamp AS transit_timestamp,
                    l.transit_mode_index,
                    l.station_complex_index,
                    l.borough_index,
                    l.payment_method_index,
                    EXTRACT(HOUR FROM l.next_transit_timestamp) AS hour_of_day,
                    SIN(2 * ACOS(-1) * EXTRACT(HOUR FROM l.next_transit_timestamp) / 24) AS hour_of_day_sin,
                    COS(2 * ACOS(-1) * EXTRACT(HOUR FROM l.next_transit_timestamp) / 24) AS hour_of_day_cos,
                    EXTRACT(DAYOFWEEK FROM l.next_transit_timestamp) AS day_of_week,
                    SIN(2 * ACOS(-1) * EXTRACT(DAYOFWEEK FROM l.next_transit_timestamp) / 7) AS day_of_week_sin,
                    COS(2 * ACOS(-1) * EXTRACT(DAYOFWEEK FROM l.next_transit_timestamp) / 7) AS day_of_week_cos,
                    CEIL(EXTRACT(DAY FROM l.next_transit_timestamp) / 7) AS week_of_month,
                    SIN(2 * ACOS(-1) * CEIL(EXTRACT(DAY FROM l.next_transit_timestamp) / 7) / 5) AS week_of_month_sin,
                    COS(2 * ACOS(-1) * CEIL(EXTRACT(DAY FROM l.next_transit_timestamp) / 7) / 5) AS week_of_month_cos,
                    l.ridership_lag_1, l.ridership_lag_2, l.ridership_lag_3, l.ridership_lag_4, l.ridership_lag_5,
                    l.ridership_lag_6, l.ridership_lag_7, l.ridership_lag_8, l.ridership_lag_9, l.ridership_lag_10,
                    l.ridership_lag_11, l.ridership_lag_12, l.ridership_lag_13, l.ridership_lag_14, l.ridership_lag_15,
                    l.ridership_lag_16, l.ridership_lag_17, l.ridership_lag_18, l.ridership_lag_19, l.ridership_lag_20,
                    l.ridership_lag_21, l.ridership_lag_22, l.ridership_lag_23, l.ridership_lag_24,
                    mv.ridership_7d_mv,
                    mv.hour_of_day_7d_mv, mv.day_of_week_7d_mv, mv.week_of_month_7d_mv,
                    mv.hour_of_day_sin_7d_mv, mv.hour_of_day_cos_7d_mv,
                    mv.day_of_week_sin_7d_mv, mv.day_of_week_cos_7d_mv,
                    mv.week_of_month_sin_7d_mv, mv.week_of_month_cos_7d_mv,
                    mv.ridership_30d_mv,
                    mv.hour_of_day_30d_mv, mv.day_of_week_30d_mv, mv.week_of_month_30d_mv,
                    mv.hour_of_day_sin_30d_mv, mv.hour_of_day_cos_30d_mv,
                    mv.day_of_week_sin_30d_mv, mv.day_of_week_cos_30d_mv,
                    mv.week_of_month_sin_30d_mv, mv.week_of_month_cos_30d_mv
                FROM `lively-encoder-448916-d5.nyc_subway.hourly_forecast_lags` l
                INNER JOIN `lively-encoder-448916-d5.nyc_subway.hourly_forecast_moving_avg` mv
                    ON l.transit_mode_index = mv.transit_mode_index
                    AND l.station_complex_index = mv.station_complex_index
                    AND l.borough_index = mv.borough_index
                    AND l.payment_method_index = mv.payment_method_index
                );
                """

In [None]:
# Step 6: Build the rows to append for the forecast hour (the next hour, not the next day).
# Each row takes its feature columns from `hourly_forecast_input` and the model's
# `prediction` from `hourly_forecast_output`. Rows are matched on transit_timestamp plus
# the four categorical indexes.
# The SELECT list order mirrors the `recent_hours_data` table built in Step 1, because
# Step 7 appends these rows with a positional `SELECT * ... UNION ALL`. Keep the two column
# lists in sync.
query_buffer = """
            CREATE OR REPLACE TABLE `lively-encoder-448916-d5.nyc_subway.hourly_forecast_buffer` AS
                SELECT
                  -- Columns from input df
                  i.transit_timestamp,
                  i.transit_mode_index,
                  i.station_complex_index,
                  i.borough_index,
                  i.payment_method_index,
                  -- Column from output df
                  o.prediction,
                  -- Other columns from input df
                  i.hour_of_day,
                  i.hour_of_day_sin,
                  i.hour_of_day_cos,
                  i.day_of_week,
                  i.day_of_week_sin,
                  i.day_of_week_cos,
                  i.week_of_month,
                  i.week_of_month_sin,
                  i.week_of_month_cos,
                  i.ridership_lag_1,
                  i.ridership_lag_2,
                  i.ridership_lag_3,
                  i.ridership_lag_4,
                  i.ridership_lag_5,
                  i.ridership_lag_6,
                  i.ridership_lag_7,
                  i.ridership_lag_8,
                  i.ridership_lag_9,
                  i.ridership_lag_10,
                  i.ridership_lag_11,
                  i.ridership_lag_12,
                  i.ridership_lag_13,
                  i.ridership_lag_14,
                  i.ridership_lag_15,
                  i.ridership_lag_16,
                  i.ridership_lag_17,
                  i.ridership_lag_18,
                  i.ridership_lag_19,
                  i.ridership_lag_20,
                  i.ridership_lag_21,
                  i.ridership_lag_22,
                  i.ridership_lag_23,
                  i.ridership_lag_24,
                  i.ridership_7d_mv,
                  i.hour_of_day_7d_mv,
                  i.day_of_week_7d_mv,
                  i.week_of_month_7d_mv,
                  i.hour_of_day_sin_7d_mv,
                  i.hour_of_day_cos_7d_mv,
                  i.day_of_week_sin_7d_mv,
                  i.day_of_week_cos_7d_mv,
                  i.week_of_month_sin_7d_mv,
                  i.week_of_month_cos_7d_mv,
                  i.ridership_30d_mv,
                  i.hour_of_day_30d_mv,
                  i.day_of_week_30d_mv,
                  i.week_of_month_30d_mv,
                  i.hour_of_day_sin_30d_mv,
                  i.hour_of_day_cos_30d_mv,
                  i.day_of_week_sin_30d_mv,
                  i.day_of_week_cos_30d_mv,
                  i.week_of_month_sin_30d_mv,
                  i.week_of_month_cos_30d_mv

                FROM `lively-encoder-448916-d5.nyc_subway.hourly_forecast_input` AS i
                INNER JOIN `lively-encoder-448916-d5.nyc_subway.hourly_forecast_output` AS o
                  ON i.transit_timestamp = o.transit_timestamp
                  AND i.transit_mode_index = o.transit_mode_index
                  AND i.station_complex_index = o.station_complex_index
                  AND i.borough_index = o.borough_index
                  AND i.payment_method_index = o.payment_method_index
            """

In [None]:
# Step 7: Append the forecast-hour rows from `hourly_forecast_buffer` to the rolling
# `recent_hours_data` table. This is an append, not a join. The next loop iteration then
# computes its lags and moving averages on top of these rows.
# `UNION ALL` of `SELECT *` matches columns by position, so the buffer must keep exactly the
# same column order as this table.
query_union = """
            CREATE OR REPLACE TABLE `lively-encoder-448916-d5.nyc_subway.recent_hours_data` AS
            SELECT * FROM `lively-encoder-448916-d5.nyc_subway.recent_hours_data`
            UNION ALL
            SELECT * FROM `lively-encoder-448916-d5.nyc_subway.hourly_forecast_buffer`;
            """

# Optional manual cleanup (not executed): rebuild the table without duplicate rows.
# -- -- Query to DROP DUPLICATES:
# -- CREATE OR REPLACE TABLE `lively-encoder-448916-d5.nyc_subway.recent_hours_data` AS
# -- SELECT DISTINCT *
# -- FROM `lively-encoder-448916-d5.nyc_subway.recent_hours_data`;

In [None]:
# Build the initial `recent_hours_data` BigQuery table (Step 1) before the forecast loop runs.
recent_data_job = client.query(query_recent_data)
recent_data_job.result()

In [None]:
import time

# Number of hours to forecast. Each iteration predicts one hour and appends it to
# `recent_hours_data`, so the next iteration's lags and moving averages build on it.
N_FORECAST_HOURS = 1

for _ in range(N_FORECAST_HOURS):
    # Rebuild the next-hour features in BigQuery: shifted lags, moving averages, joined input.
    client.query(query_lags).result()
    client.query(query_ma).result()
    client.query(query_input).result()
    # NOTE(review): .result() already blocks until each job completes; this pause is
    # probably unnecessary.
    time.sleep(4)

    # Read the model input from BigQuery into a Spark DataFrame.
    input_hourly_df = (
        spark.read
        .format("bigquery")
        .option("table", "lively-encoder-448916-d5.nyc_subway.hourly_forecast_input")
        .option("parentProject", "lively-encoder-448916-d5")
        .load()
    )

    # Assumes every input row is for the same forecast hour; take it from the first row
    # before the column is dropped (the model is fed the features without the timestamp).
    first_row = input_hourly_df.select("transit_timestamp").first()
    if first_row is None:
        # first() returns None for an empty DataFrame; fail with a clear message instead of
        # a TypeError from subscripting None.
        raise RuntimeError(
            "hourly_forecast_input is empty; check the lag / moving-average tables "
            "before running the model"
        )
    df_transit_timestamp = first_row["transit_timestamp"]
    print(df_transit_timestamp)

    input_hourly_df = input_hourly_df.drop("transit_timestamp")

    # Predict, re-attach the forecast hour, and keep the key columns plus an integer prediction.
    output_predictions = (
        model.transform(input_hourly_df)
        .withColumn("transit_timestamp", to_timestamp(lit(df_transit_timestamp)))
        .withColumn("prediction", col("prediction").cast(IntegerType()))
        .select('transit_timestamp', 'transit_mode_index', 'station_complex_index',
                'borough_index', 'payment_method_index', 'prediction')
    )

    # Overwrite the per-iteration output table that Step 6 (query_buffer) joins against.
    output_predictions.write.format("bigquery") \
        .option("table", "lively-encoder-448916-d5.nyc_subway.hourly_forecast_output") \
        .option("temporaryGcsBucket", "temp_dec_forecast_bucket") \
        .mode("overwrite") \
        .save()

    time.sleep(4)
    # Steps 6 + 7: attach features to the predictions and append them to recent_hours_data.
    client.query(query_buffer).result()
    client.query(query_union).result()

    time.sleep(1)

In [None]:
# Keep only the forecast rows: everything in `recent_hours_data` newer than the last observed
# hour in `hour_model`. They are written to the `hour_forecast` table, still index-encoded.
query_forecast = """
                CREATE OR REPLACE TABLE `lively-encoder-448916-d5.nyc_subway.hour_forecast` AS
                SELECT * FROM `lively-encoder-448916-d5.nyc_subway.recent_hours_data`
                WHERE transit_timestamp > (SELECT MAX(transit_timestamp) from `lively-encoder-448916-d5.nyc_subway.hour_model`)
                ORDER BY transit_timestamp DESC;
                """
forecast_job = client.query(query_forecast)
forecast_job.result()

In [None]:
# Load the index-encoded forecast rows from BigQuery into Spark for label mapping.
hour_forecast = (
    spark.read
    .format("bigquery")
    .options(
        table="lively-encoder-448916-d5.nyc_subway.hour_forecast",
        parentProject="lively-encoder-448916-d5",
    )
    .load()
)

In [None]:
from pyspark.sql.functions import col, lit, create_map, to_date
# PySpark's round is imported under an alias: a bare `from ... import round` would shadow
# Python's builtin round() in every cell executed after this one.
from pyspark.sql.functions import round as spark_round


def _to_map_expr(mapping):
    """Return a Spark map-literal column built from a Python dict's key/value pairs."""
    return create_map(*[lit(item) for pair in mapping.items() for item in pair])


# Encoded index -> original label for each categorical column.
payment_method_mapping = {1: 'metrocard', 2: 'omny'}
transit_mode_mapping = {1: 'staten_island_railway', 2: 'subway', 3: 'tram'}
borough_mapping = {1: 'Bronx', 2: 'Brooklyn', 3: 'Manhattan', 4: 'Queens', 5: 'Staten Island'}

# Station labels come from a CSV lookup (station_complex_index -> station_complex).
station_df = (
    spark.read.csv("gs://bucket_jars/station_complex_csv/station_complex.csv", header=True)
    .withColumn("station_complex_index", col("station_complex_index").cast("int"))
    .select("station_complex_index", "station_complex")
    # A non-numeric index casts to NULL, and create_map rejects NULL keys when evaluated.
    .dropna(subset=["station_complex_index"])
)
station_map = station_df.rdd.collectAsMap()

# Map-literal expressions (integer keys) for each categorical column.
payment_method_expr = _to_map_expr(payment_method_mapping)
transit_mode_expr = _to_map_expr(transit_mode_mapping)
borough_expr = _to_map_expr(borough_mapping)
station_complex_expr = _to_map_expr(station_map)

# Look up each index, cast to int to match the integer map keys, in its label map.
hours_forecasted = (
    hour_forecast
    .withColumn("payment_method", payment_method_expr[col("payment_method_index").cast("int")])
    .withColumn("transit_mode", transit_mode_expr[col("transit_mode_index").cast("int")])
    .withColumn("borough", borough_expr[col("borough_index").cast("int")])
    .withColumn("station_complex", station_complex_expr[col("station_complex_index").cast("int")])
)

In [None]:
# Keep the human-readable label columns and the forecast value, in output order.
hours_forecasted = hours_forecasted.select(
    'transit_timestamp',
    'transit_mode',
    'station_complex',
    'borough',
    'payment_method',
    'prediction',
)

In [None]:
# Persist the labelled forecast to BigQuery.
# NOTE(review): this overwrites `hour_forecast`, the same table the index-encoded
# `hour_forecast` DataFrame is read from earlier. After this runs, re-running that read cell
# loads the labelled schema, which lacks the *_index columns the mapping cell expects.
# Consider writing to a separate output table.
(
    hours_forecasted.write
    .format("bigquery")
    .option("table", "lively-encoder-448916-d5.nyc_subway.hour_forecast")
    .option("temporaryGcsBucket", "temp_nyc_bucket_for_bq")
    .mode("overwrite")
    .save()
)