# Imports

In [0]:

import pyspark.sql.functions as F
from pyspark.sql.functions import col,isnan, when, count, concat_ws, countDistinct, collect_set, rank, window, avg, hour, udf, isnan, pandas_udf, to_timestamp, lit, PandasUDFType
import matplotlib.pyplot as plt
import pandas as pd
import re
import pytz
from datetime import datetime, timedelta, time
import numpy as np
from pyspark.sql import types

from pyspark.sql.types import StructType, StructField, StringType, IntegerType, StructType, DoubleType, LongType

from pyspark.sql import Window
import seaborn as sns

In [0]:
!pip install python-geohash
import geohash
from geohash import bbox

In [0]:
# Load the joined checkpoint (time + seasonal features, per the path name) as the input to weather cleaning.
df=spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat.parquet/")

In [0]:
# Team root folder on DBFS; Spark checkpoint data is directed to its interim/ sub-folder.
team_BASE_DIR = f"dbfs:/student-groups/Group_4_1"
spark.sparkContext.setCheckpointDir(f"{team_BASE_DIR}/interim")

# Weather Cleaning


In [0]:


# METAR wind group (e.g. "...KT"): capture group 2 is the sustained speed, group 3 the optional gust.
_metar_wind_pattern = r'\b(\d{3})(\d{2,3})(?:G(\d{2,3}))?KT\b'
# Peak-wind remark: capture group 2 is the peak wind speed.
_peak_wind_pattern = r'PK WND (\d{3})(\d{2,3})/(\d{4})'

_remarks = F.col("origin_REM")
_wind_speed = F.col("origin_HourlyWindSpeed")
_gust_speed = F.col("origin_HourlyWindGustSpeed")

# regexp_extract yields "" when nothing matches; the int cast turns that into null.
_sustained_from_remarks = F.regexp_extract(_remarks, _metar_wind_pattern, 2).cast("int")
# greatest() picks the larger of the G-group gust and the PK WND peak.
_gust_from_remarks = F.greatest(
    F.regexp_extract(_remarks, _metar_wind_pattern, 3).cast("int"),
    F.regexp_extract(_remarks, _peak_wind_pattern, 2).cast("int"),
)

# Step 1: only where the hourly wind fields are null, back-fill them from the remarks text.
df_interpolate = (
    df
    .withColumn(
        "origin_HourlyWindSpeed",
        F.when(_wind_speed.isNull(), _sustained_from_remarks).otherwise(_wind_speed),
    )
    .withColumn(
        "origin_HourlyWindGustSpeed",
        F.when(_gust_speed.isNull(), _gust_from_remarks).otherwise(_gust_speed),
    )
)

_precip = F.col("origin_HourlyPrecipitation")
# Hundredths of an inch are kept in the "remarks" section as " P<digits>".
_precip_from_remarks = F.regexp_extract(_remarks, r" P(\d+)", 1).cast("int") * 0.01

# Step 2: precipitation - fill null / '*' from the remarks, replace 'T' with '0.01',
# keep only the leading numeric token and cast the result to double.
df_interpolate = (
    df_interpolate
    .withColumn(
        'origin_HourlyPrecipitation',
        F.when(_precip.isNull() | (_precip == '*'), _precip_from_remarks).otherwise(_precip),
    )
    .withColumn('origin_HourlyPrecipitation', F.regexp_replace('origin_HourlyPrecipitation', 'T', '0.01'))
    .withColumn(
        'origin_HourlyPrecipitation',
        F.regexp_extract('origin_HourlyPrecipitation', r"[0-9]+(\.[0-9]+)?", 0),  # match digits
    )
    .withColumn('origin_HourlyPrecipitation', F.col('origin_HourlyPrecipitation').cast(DoubleType()))
)

In [0]:


def encode_geohash(precision: int):
    """Return a pandas UDF that geohashes (latitude, longitude) pairs at `precision`.

    The returned UDF takes two Series (latitudes, longitudes) and yields a string
    Series. Any pair for which `geohash.encode` raises maps to None instead of
    failing the whole batch.
    """
    def _encode_or_none(lat, lon):
        # Best-effort: a bad coordinate must not abort the batch.
        try:
            return geohash.encode(lat, lon, precision)
        except Exception:
            return None

    @pandas_udf("string")
    def encode(latitudes: pd.Series, longitudes: pd.Series) -> pd.Series:
        return latitudes.combine(longitudes, _encode_or_none)

    return encode

# Tag every row with a precision-2 geohash of the origin coordinates; this column is the
# default partition key (`geohash_col="geohash"`) of coalesce_within_geohash.
geohash_udf = encode_geohash(precision=2)
df_interpolate = df_interpolate.withColumn('geohash', geohash_udf(F.col('origin_LATITUDE'), F.col('origin_LONGITUDE')))
display(df_interpolate)

In [0]:
def coalesce_within_geohash(
    df,
    target_col,
    geohash_col="geohash",
    dt_col="sched_depart_utc",
    window_size=6
):
    """Replace `target_col` with the most recent non-null value seen in the same geohash.

    Rows are partitioned by `geohash_col` and ordered by `dt_col` (cast to long).
    For each row the last non-null `target_col` among the current row and the
    `window_size` rows before it is taken; the result is null when that frame
    holds no non-null value.
    """
    time_order = F.col(dt_col).cast("long")
    trailing_rows = (
        Window.partitionBy(geohash_col)
        .orderBy(time_order)
        .rowsBetween(-window_size, 0)
    )
    latest_known = F.last(target_col, ignorenulls=True).over(trailing_rows)
    return df.withColumn(target_col, latest_known)

In [0]:
# Cast every hourly weather measurement to float so the fill steps work on numeric columns.
df_interpolated = df_interpolate.withColumns({
    weather_col: df_interpolate[weather_col].cast("float")
    for weather_col in (
        "origin_HourlyPrecipitation",
        "origin_HourlyWindGustSpeed",
        "origin_HourlyWindSpeed",
        "origin_HourlyDewPointTemperature",
        "origin_HourlyDryBulbTemperature",
        "origin_HourlyPressureChange",
        "origin_HourlyRelativeHumidity",
        "origin_HourlyWetBulbTemperature",
        "origin_HourlyVisibility",
    )
})

In [0]:
# Weather columns whose nulls are filled from recent observations in the same geohash.
columns_to_fill = ['origin_HourlyVisibility','origin_HourlyWindSpeed','origin_HourlyDewPointTemperature','origin_HourlyDryBulbTemperature','origin_HourlyPressureChange','origin_HourlyRelativeHumidity','origin_HourlyWetBulbTemperature','origin_HourlyPrecipitation','origin_HourlyWindGustSpeed']

# The loop variable is deliberately not called `col`: that name is imported from
# pyspark.sql.functions at the top of the notebook and a loop would overwrite it.
for fill_col in columns_to_fill:
    df_interpolated = coalesce_within_geohash(df_interpolated, fill_col)



In [0]:

# One count per weather column of rows that are null or NaN.
_missing_counts = [
    F.count(F.when(F.col(c).isNull() | F.isnan(c), c)).alias(c)
    for c in columns_to_fill
]
# Only rows with a known location are counted (a null location cannot be filled by geohash).
_has_location = F.col('origin_LATITUDE').isNotNull()

null_counts = df_interpolated.filter(_has_location).select(_missing_counts)
null_counts_orig = df_interpolate.filter(_has_location).select(_missing_counts)

# First row: before the geohash fill; second row: after it.
display(null_counts_orig.unionByName(null_counts)) #null counts where we have non null loc (obviously, null loc is impossible to fill in by geohash)

In [0]:


@pandas_udf(DoubleType())
def exponential_smoothing_pandas(values: pd.Series) -> pd.Series:
    """Vectorized UDF returning one exponentially smoothed value (alpha=0.5) per row.

    The caller passes the output of `collect_list`, so each element of `values`
    is an array holding one row's trailing window. Such a Series has object
    dtype; the former `is_numeric_dtype` guard therefore always fired and every
    row received 0.0 instead of a smoothed value. Each window is now smoothed on
    its own and its most recent EWM value is returned.

    Returns a float Series aligned with `values`. An empty or missing window
    yields 0.0, which keeps the original fallback value.
    NOTE(review): 0.0 is a questionable default for some weather columns -
    confirm whether null would be preferable downstream.
    """
    # Backward-compatible path: a plain numeric column is smoothed across the batch as before.
    if not values.empty and pd.api.types.is_numeric_dtype(values):
        return values.ewm(alpha=0.5, ignore_na=True).mean()

    def _latest_ewm(window_values) -> float:
        if window_values is None or len(window_values) == 0:
            return 0.0
        window = pd.Series(window_values, dtype="float64")
        return float(window.ewm(alpha=0.5, ignore_na=True).mean().iloc[-1])

    return values.apply(_latest_ewm).astype("float64")

def smooth_column_optimized(
    df,
    col_name,
    station_col="origin_STATION",
    dt_col="sched_depart_date_time",
    window_size=6
):
    """Fill the remaining nulls of `col_name` with an exponentially smoothed value.

    For each row, the non-null values of `col_name` among the current row and
    the `window_size` rows before it (same `station_col`, ordered by `dt_col`)
    are collected and passed to `exponential_smoothing_pandas`. Existing values
    are kept; only nulls take the smoothed value. `col_name` is cast to double.

    NOTE(review): the default `dt_col` differs from the "sched_depart_utc"
    default used by coalesce_within_geohash - confirm that is intended.
    """
    numeric_df = df.withColumn(col_name, F.col(col_name).cast(DoubleType()))

    trailing_rows = (
        Window.partitionBy(station_col)
        .orderBy(F.col(dt_col).cast("long"))
        .rowsBetween(-window_size, 0)
    )

    # collect_list drops nulls, so the array holds only observed values.
    with_history = numeric_df.withColumn(
        "non_null_values", F.collect_list(col_name).over(trailing_rows)
    )
    with_smoothed = with_history.withColumn(
        "smoothed", exponential_smoothing_pandas("non_null_values")
    )
    filled = with_smoothed.withColumn(
        col_name, F.coalesce(F.col(col_name), F.col("smoothed"))
    )
    return filled.drop("non_null_values", "smoothed")


In [0]:
# Weather columns whose remaining nulls are filled by per-station exponential smoothing.
columns_to_fill = ['origin_HourlyWindSpeed','origin_HourlyDewPointTemperature','origin_HourlyDryBulbTemperature','origin_HourlyPressureChange','origin_HourlyRelativeHumidity','origin_HourlyWetBulbTemperature','origin_HourlyPrecipitation','origin_HourlyWindGustSpeed','origin_HourlyVisibility']

# The loop variable is deliberately not called `col`: that name is imported from
# pyspark.sql.functions at the top of the notebook and a loop would overwrite it.
for fill_col in columns_to_fill:
    df_interpolated = smooth_column_optimized(df_interpolated, fill_col)



In [0]:
filled_cols = [
    'origin_HourlyVisibility',
    'origin_HourlyWindSpeed',
    'origin_HourlyDewPointTemperature',
    'origin_HourlyDryBulbTemperature',
    'origin_HourlyPressureChange',
    'origin_HourlyRelativeHumidity',
    'origin_HourlyWetBulbTemperature',
    'origin_HourlyPrecipitation',
    'origin_HourlyWindGustSpeed',
]

# Null/NaN count per filled column after the smoothing step, restricted to rows with a known location.
null_counts_ema = (
    df_interpolated
    .filter(F.col('origin_LATITUDE').isNotNull())
    .select([
        F.count(F.when(F.col(c).isNull() | F.isnan(c), c)).alias(c)
        for c in filled_cols
    ])
)


In [0]:
# Rows, in order: original null counts, after the geohash fill, after exponential smoothing.
display(null_counts_orig.unionByName(null_counts).unionByName(null_counts_ema))

In [0]:
# Persist the weather-cleaned frame; the Lags section reads this path back.
df_interpolated.write.mode("overwrite").parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned.parquet/")

# Lags

## Create features

In [0]:
# Reload the cleaned checkpoint. NOTE: this rebinds `df`, which earlier held the uncleaned join.
df = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned.parquet/")

In [0]:
# Inspection subset: flights with a tail number scheduled after 2019-04-01, ordered per aircraft by time.
df_small = (
    df
    .filter(F.col('TAIL_NUM').isNotNull())
    .filter(F.col('sched_depart_utc') > '2019-04-01')
    .orderBy('TAIL_NUM', 'sched_depart_utc')
)

In [0]:
# Inspect the prior-flight columns on the small subset.
# 'priorflight_deptime_calc' was listed twice in the projection; it is selected once here.
display(df_small.select(
    'TAIL_NUM', 'ORIGIN', 'DEST', 'sched_depart_utc', 'CANCELLED',
    'two_hours_prior_depart_UTC', 'priorflight_origin', 'priorflight_dest',
    'priorflight_deptime_calc', 'priorflight_cancelled_true',
))

In [0]:
# Same inspection over every flight with a tail number, ordered per aircraft by scheduled departure.
# 'priorflight_deptime_calc' was listed twice in the projection; it is selected once here.
display(
    df.filter(F.col('TAIL_NUM').isNotNull())
    .select(
        'TAIL_NUM', 'ORIGIN', 'DEST', 'sched_depart_utc', 'CANCELLED',
        'two_hours_prior_depart_UTC', 'priorflight_origin', 'priorflight_dest',
        'priorflight_deptime_calc', 'priorflight_cancelled_true',
    )
    .orderBy('TAIL_NUM', 'sched_depart_utc')
)

In [0]:
# Spot-check one aircraft (tail N102UW) in departure order.
# 'priorflight_deptime_calc' was listed twice in the projection; it is selected once here.
display(
    df.filter(F.col('TAIL_NUM')=='N102UW')
    .select(
        'ORIGIN', 'DEST', 'sched_depart_utc', 'CANCELLED',
        'two_hours_prior_depart_UTC', 'priorflight_origin', 'priorflight_dest',
        'priorflight_deptime_calc', 'priorflight_cancelled_true',
    )
    .orderBy('sched_depart_utc')
)

In [0]:
# Flights that were cancelled and whose prior flight (same tail) was cancelled too.
# 'priorflight_deptime_calc' was listed twice in the projection; it is selected once here.
display(
    df.filter((F.col('CANCELLED') == 1) & (F.col('priorflight_cancelled_true') == 1))
    .filter(F.col('TAIL_NUM').isNotNull())
    .select(
        'TAIL_NUM', 'ORIGIN', 'DEST', 'sched_depart_utc', 'CANCELLED',
        'two_hours_prior_depart_UTC', 'priorflight_origin', 'priorflight_dest',
        'priorflight_deptime_calc', 'priorflight_cancelled_true',
    )
    .orderBy('TAIL_NUM', 'sched_depart_utc')
)

In [0]:
"2019-10-27T19:46:00.000+00:00".tzinfo

In [0]:
def to_utc(yyyymmdd, dep_hhmm, arr_hhmm, dep_tz, arr_tz, flight_dur):
    """
    Build UTC departure and arrival timestamps from flights-table columns.

    yyyymmdd = FL_DATE, formatted 'YYYY-MM-DD'
    dep_hhmm = CRS_DEP_TIME (local HHMM at the origin)
    arr_hhmm = CRS_ARR_TIME (local HHMM at the destination)
    dep_tz = origin_timezone
    arr_tz = dest_timezone
    flight_dur = CRS_ELAPSED_TIME (currently unused; kept for the disabled sanity check)

    An hour of 24 is read as 00:MM on the following day. If the arrival in UTC
    would precede the departure, the arrival is moved one day forward.

    Returns (dep_utc, arr_utc), each formatted as '%Y-%m-%dT%H:%M:%S'.
    """
    dep_clock = int(dep_hhmm)
    arr_clock = int(arr_hhmm)
    year, month, day = [int(part) for part in yyyymmdd.split('-')]

    def local_to_utc(hhmm, tz_name):
        # Split HHMM, roll an hour of 24 into the next day, then localize and convert.
        hour, minute = divmod(hhmm, 100)
        next_day = hour == 24
        naive = datetime(year, month, day, 0 if next_day else hour, minute)
        if next_day:
            naive += timedelta(days=1)
        return pytz.timezone(tz_name).localize(naive).astimezone(pytz.utc)

    dep_utc = local_to_utc(dep_clock, dep_tz)
    arr_utc = local_to_utc(arr_clock, arr_tz)

    # Arrival earlier than departure: the flight lands on the next day.
    if dep_utc > arr_utc:
        arr_utc += timedelta(days=1)

    dt_format = "%Y-%m-%dT%H:%M:%S"
    return (dep_utc.strftime(dt_format), arr_utc.strftime(dt_format))

# Struct returned by to_utc: two non-nullable strings (UTC departure and arrival).
schema = StructType([
    StructField("dep_datetime", StringType(), False),
    StructField("arr_datetime", StringType(), False),
])

# Row-at-a-time Python UDF wrapping to_utc.
dt_udf = udf(to_utc, schema)

# Compute the struct once into 'processed' and cache the result.
out = df.withColumn('processed', 
                                 dt_udf(F.col("FL_DATE"), 
                                        F.col("CRS_DEP_TIME"), 
                                        F.col("CRS_ARR_TIME"), 
                                        F.col("origin_timezone"), 
                                        F.col("dest_timezone"), 
                                        F.col("CRS_ELAPSED_TIME"))
                                 ).cache()

# Flatten: keep every original column and promote the two struct fields to top-level columns.
cols = [c for c in out.columns if c != "processed"]
cols += ["processed.dep_datetime","processed.arr_datetime"]
out = out.select(cols)

display(out)

In [0]:

# Derive `sched_arr_utc`: the scheduled arrival (FL_DATE + CRS_ARR_TIME, interpreted in
# dest_timezone) expressed in UTC. Only df's own columns plus `sched_arr_utc` are kept.
df_dated = (df
.withColumn("FL_YMD", F.col("FL_DATE").substr(0,10))
.withColumn(
    "base_date", F.to_date(F.col("FL_YMD"), "yyyy-MM-dd")
).withColumn(
    # 2400 is not a valid clock time; treat it as 0000 (the day shift is applied below).
    "arr_time_adj",
    F.when(F.col("CRS_ARR_TIME") == 2400, F.lit("0000")).otherwise(F.col("CRS_ARR_TIME"))
).withColumn(
    "arr_time_str",
    F.lpad("arr_time_adj", 4, "0")
).withColumn(
    "arr_timestamp",
    F.concat(
        F.col("FL_YMD"),
        F.lit("T"),
        F.substring("arr_time_str", 1, 2),
        F.lit(":"),
        F.substring("arr_time_str", 3, 2),
        F.lit(":00")
    )
).withColumn(
    "local_datetime",
    F.to_timestamp("arr_timestamp", "yyyy-MM-dd'T'HH:mm:ss")
).withColumn(
    "arr_utc",
    F.when(
        F.col("CRS_ARR_TIME") == 2400,
        # date_add returns a DATE; harmless here because the 2400 case is exactly midnight.
        F.to_utc_timestamp(F.date_add("local_datetime", 1), F.col("dest_timezone"))
    ).otherwise(
        F.to_utc_timestamp("local_datetime", F.col("dest_timezone"))
    )
).withColumn(
    "sched_arr_utc",
    # Arrival before departure means the flight lands on the next day. Shift with an
    # interval so the time of day is preserved: F.date_add returns a DATE, which
    # truncated every overnight arrival to midnight.
    F.when(
        F.col("sched_depart_utc") > F.col("arr_utc"),
        F.col("arr_utc") + F.expr("INTERVAL 1 DAY")
    ).otherwise(F.col("arr_utc"))
).withColumn(
    # NOTE(review): this is derived from the *scheduled* departure and re-applies the
    # origin timezone. It survives the select below only if df already has an
    # `actual_depart_utc` column (which it then overwrites) - confirm the intent.
    "actual_depart_utc",
    F.to_utc_timestamp(F.col("sched_depart_utc"), F.col("origin_timezone"))

).select(
    *df.columns,
    "sched_arr_utc"
))

In [0]:
# Derive `actual_arr_utc`: the actual arrival (FL_DATE + ARR_TIME, interpreted in
# dest_timezone) expressed in UTC. Temporary columns are dropped in a later cell.
df_dated = (df_dated
.withColumn("FL_YMD", F.col("FL_DATE").substr(0,10))
.withColumn(
    "base_date", F.to_date(F.col("FL_YMD"), "yyyy-MM-dd")
).withColumn(
    # 2400 is not a valid clock time; treat it as 0000 (the day shift is applied below).
    "arr_time_adj",
    F.when(F.col("ARR_TIME") == 2400, F.lit("0000")).otherwise(F.col("ARR_TIME"))
).withColumn(
    "arr_time_str",
    F.lpad("arr_time_adj", 4, "0")
).withColumn(
    "arr_timestamp",
    F.concat(
        F.col("FL_YMD"),
        F.lit("T"),
        F.substring("arr_time_str", 1, 2),
        F.lit(":"),
        F.substring("arr_time_str", 3, 2),
        F.lit(":00")
    )
).withColumn(
    "local_datetime",
    F.to_timestamp("arr_timestamp", "yyyy-MM-dd'T'HH:mm:ss")
).withColumn(
    "arr_utc",
    F.when(
        F.col("ARR_TIME") == 2400,
        # date_add returns a DATE; harmless here because the 2400 case is exactly midnight.
        F.to_utc_timestamp(F.date_add("local_datetime", 1), F.col("dest_timezone"))
    ).otherwise(
        F.to_utc_timestamp("local_datetime", F.col("dest_timezone"))
    )
).withColumn(
    "actual_arr_utc",
    # Arrival before departure means the flight landed on the next day. Shift with an
    # interval so the time of day is preserved: F.date_add returns a DATE, which
    # truncated every overnight arrival to midnight.
    F.when(
        F.col("actual_depart_utc") > F.col("arr_utc"),
        F.col("arr_utc") + F.expr("INTERVAL 1 DAY")
    ).otherwise(F.col("arr_utc"))
    )
)

In [0]:
# Remove the scratch columns used to assemble the arrival timestamps.
df_dated = df_dated.drop(
    'arr_time_adj',
    'arr_time_str',
    'arr_timestamp',
    'arr_utc',
    'local_datetime',
)

In [0]:
# Number of rows with no origin latitude.
df.filter(F.col('origin_LATITUDE').isNull()).count()

In [0]:


def add_lags_optimized(df):
    """Add prior-flight (same TAIL_NUM) lag columns and the features derived from them.

    The previous flight of each aircraft is found with `lag` over a window
    partitioned by TAIL_NUM and ordered by `sched_depart_utc`. A prior flight is
    "valid" when it landed at this flight's ORIGIN, was scheduled to depart no
    earlier than `twentysix_hours_prior_depart_UTC` (two_hours_prior_depart_UTC
    minus 24 hours) and was not cancelled. Only information available at
    `two_hours_prior_depart_UTC` is used for the *_calc columns.

    Reads: TAIL_NUM, ORIGIN, DEST, CANCELLED, sched_depart_utc, actual_depart_utc,
    sched_arr_utc, actual_arr_UTC, CRS_ELAPSED_TIME, ACTUAL_ELAPSED_TIME,
    DEP_DELAY, two_hours_prior_depart_UTC.
    NOTE(review): `actual_arr_UTC` differs in case from the `actual_arr_utc`
    column built earlier in this notebook; this resolves only while Spark
    column resolution is case-insensitive.

    Returns the input columns plus the lag/derived columns; the result is cached.
    """
    aircraft_window = Window.partitionBy("TAIL_NUM").orderBy('sched_depart_utc')

    # Moment at which the prediction is made; later information must not be used.
    cutoff = F.col("two_hours_prior_depart_UTC")

    WhenConditions = (
        (F.col("ORIGIN") == F.col("priorflight_dest")) &
        (F.col("priorflight_sched_deptime") >= F.col("twentysix_hours_prior_depart_UTC"))
    )

    # All lagged columns are computed in a single select over the same window.
    lagged_cols = [
        F.lag("CANCELLED").over(aircraft_window).alias("priorflight_cancelled_true"),
        F.lag("ORIGIN").over(aircraft_window).alias("priorflight_origin"),
        F.lag("DEST").over(aircraft_window).alias("priorflight_dest"),
        F.lag("sched_depart_utc").over(aircraft_window).alias("priorflight_sched_deptime"),
        F.lag("actual_depart_utc").over(aircraft_window).alias("priorflight_true_deptime"),
        F.lag("CRS_ELAPSED_TIME").over(aircraft_window).alias("priorflight_sched_elapsed"),
        F.lag("ACTUAL_ELAPSED_TIME").over(aircraft_window).alias("priorflight_true_elapsed"),
        F.lag("DEP_DELAY").over(aircraft_window).alias("priorflight_true_depdelay"),
        F.lag("sched_arr_utc").over(aircraft_window).alias("priorflight_sched_arrtime"),
        F.lag("actual_arr_UTC").over(aircraft_window).alias("priorflight_true_arrtime")
    ]

    valid_prior = WhenConditions & (F.col("priorflight_cancelled_true") == 0)
    departed_by_cutoff = F.col("priorflight_true_deptime") <= cutoff

    # Base transformations
    base_df = (df
        .withColumn("twentysix_hours_prior_depart_UTC",
                   (cutoff - F.expr("INTERVAL 24 HOURS")).cast("timestamp"))
        .select("*", *lagged_cols)
    )

    # Core calculations
    result_df = (base_df
        # Elapsed minutes become intervals so they can be added to timestamps; null if prior is invalid.
        .withColumn("priorflight_sched_elapsed",
            F.when(valid_prior,
                F.expr("INTERVAL 1 MINUTE") * F.col("priorflight_sched_elapsed")
            )
        )
        .withColumn("priorflight_true_elapsed",
            F.when(valid_prior,
                F.expr("INTERVAL 1 MINUTE") * F.col("priorflight_true_elapsed")
            )
        )
        .withColumn("priorflight_isdeparted",
            F.when(departed_by_cutoff & valid_prior, 1).otherwise(0)
        )
        # (An earlier definition of priorflight_depdelay_calc was overwritten by this one
        # before any use and has been removed.)
        .withColumn("priorflight_depdelay_calc",
            F.when(
                # Already departed at the cutoff: the true delay is known.
                departed_by_cutoff & valid_prior,
                F.col("priorflight_true_depdelay")
            ).when(
                # Scheduled to have left but still on the ground: delay so far, in minutes.
                (F.col("priorflight_sched_deptime") <= cutoff) &
                (F.col("priorflight_true_deptime") > cutoff) &
                valid_prior,
                (cutoff.cast('long') -
                 F.col("priorflight_sched_deptime").cast('long')) / 60
            ).otherwise(F.lit(0.0)) #if not enough info, assume all is well - in line with other logic; edge cases handled later
        )
        .withColumn("priorflight_deptime_calc",
            F.col("priorflight_sched_deptime") +
            (F.expr("INTERVAL 1 MINUTE") * F.col("priorflight_depdelay_calc"))
        ) #non-valid prior depdelay calcs will get rewritten over in edge case handling

        .withColumn("priorflight_isdelayed_calc",
            F.when(
                (F.col("priorflight_depdelay_calc") >= 15) |
                (F.col('priorflight_cancelled_true') == 1), 1
            ).otherwise(0)
        )

        .withColumn("priorflight_isarrived_calc",
            F.when(
                (F.col("priorflight_true_arrtime") <= cutoff) &
                valid_prior, 1
            ).otherwise(0)
        )
        .withColumn("priorflight_arr_time_calc",
            F.when(
                # Landed before the cutoff: use the true arrival time.
                F.col("priorflight_isarrived_calc") == 1,
                F.col("priorflight_true_arrtime")
            ).when(
                # In the air at the cutoff: true departure + scheduled duration.
                (F.col("priorflight_isarrived_calc") == 0) & departed_by_cutoff,
                F.col("priorflight_true_deptime") + F.col("priorflight_sched_elapsed")
            ).otherwise(
                # Not yet departed: estimated departure + scheduled duration.
                F.col("priorflight_deptime_calc") + F.col("priorflight_sched_elapsed")
            )
        )
        # Minutes between the estimated prior arrival and this scheduled departure.
        .withColumn("turnaround_time_calc",
            F.when(valid_prior,
                ((F.col("sched_depart_utc").cast("long") -
                  F.col("priorflight_arr_time_calc").cast("long")) / 60).cast("double")
            ).otherwise(F.lit(None))
        )

    ).cache()

    return result_df

# # Execute pipeline
# result = add_lags_optimized(out)
# display(result)


In [0]:
# Load the "_pr" variant of the cleaned checkpoint. NOTE: this rebinds `df` again.
team_BASE_DIR = f"dbfs:/student-groups/Group_4_1"

df = spark.read.parquet(f"{team_BASE_DIR}/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned_pr.parquet")


In [0]:
# Per-aircraft window ordered by scheduled departure; lag() over it yields the previous flight of the same tail.
aircraft_window = Window.partitionBy("TAIL_NUM").orderBy('sched_depart_utc')

    # route_window = Window.partitionBy("ORIGIN", "DEST").orderBy("sched_depart_utc").rowsBetween(-10, -1)

# Prior flight landed at this origin and was scheduled no earlier than the 26-hour lookback column.
WhenConditions = (
        (F.col("ORIGIN") == F.col("priorflight_dest")) & 
        (F.col("priorflight_sched_deptime") >= F.col("twentysix_hours_prior_depart_UTC"))
    )


In [0]:
# Carrier of the aircraft's previous flight (null for the first flight of each tail).
df_= df.withColumn("priorflight_carrier", F.lag("OP_UNIQUE_CARRIER").over(aircraft_window))

In [0]:
# `mean_dep_delay` value carried over from the aircraft's previous flight row.
df_= df_.withColumn("priororigin_mean_dep_delay", F.lag("mean_dep_delay").over(aircraft_window))

In [0]:
# `origin_type` value carried over from the aircraft's previous flight row.
df_= df_.withColumn("priororigin_type", F.lag("origin_type").over(aircraft_window))

In [0]:
# Quick sanity check: summary statistics over the first 100 rows only.
df_.select('priororigin_mean_dep_delay').limit(100).summary().show()

In [0]:
# Quick sanity check: summary statistics of route_risk over the first 100 rows only.
df.select('route_risk').limit(100).summary().show()

In [0]:
# Persist the frame with the three prior-flight columns added above as the "_pr_v2" checkpoint.
df_.write.mode("overwrite").parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned_pr_v2.parquet/")

## Don't run: imputation

In [0]:

# A prior flight is valid when it landed at this origin, falls inside the 26-hour lookback and was not cancelled.
WhenConditions = (
        (F.col("ORIGIN") == F.col("priorflight_dest")) & 
        (F.col("priorflight_sched_deptime") >= F.col("twentysix_hours_prior_depart_UTC"))
    )
valid_prior = WhenConditions & (F.col("priorflight_cancelled_true") == 0)

# partition by route (ORIGIN->DEST) and flight date; range frame is measured in seconds
# NOTE(review): because FL_DATE is part of the partition key, the 48-hour range can only
# reach rows that share the same FL_DATE - confirm this is intended.
hours = lambda i: i * 3600
window_spec = Window.partitionBy(F.col("ORIGIN"),F.col("DEST"), F.col("FL_DATE")) \
    .orderBy(F.col("sched_depart_utc").cast("long")
             ) \
        .rangeBetween(-hours(48),0)
# we will eventually get just -4 to -2 hours, but using 0 in the window allows us to
# grab the utc-2 for the 0 hour offset case

# NOTE(review): `df_lags` is not defined in the visible part of this notebook.
df_routes = df_lags.repartition("ORIGIN", "DEST", "FL_DATE")


@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def mean_turnarounds_udf(turnarounds: pd.Series,
                       act_dep_times: pd.Series,
                       sched_dep_utc2: pd.Series) -> float:
    """Mean turnaround over the rows of the frame that had departed before the cutoff.

    turnarounds    : turnaround times of the rows in the window frame
    act_dep_times  : actual departure times of those rows
    sched_dep_utc2 : the rows' two-hours-prior-departure times; the maximum in the
                     frame is used as the cutoff

    Returns the NaN-ignoring mean of the selected turnarounds (NaN when nothing
    qualifies).
    """
    departed_before_cutoff = act_dep_times < np.max(sched_dep_utc2)
    # `np.float` was removed in NumPy 1.24; np.float64 is the dtype it aliased.
    d = turnarounds[departed_before_cutoff].astype(np.float64)
    return np.nanmean(d)

# Apply the UDF over the window
df_lags_imputed = df_routes \
    .withColumn("mean_turnaround_calc", 
        F.when(~valid_prior,
        mean_turnarounds_udf(
                F.col("turnaround_time_calc"),
                F.col("actual_depart_utc"),
                F.col("two_hours_prior_depart_UTC")
            ).over(window_spec)).otherwise(F.col("turnaround_time_calc"))
        
        )

df_lags_imputed.cache()
display(df_lags_imputed)

In [0]:
# Rows whose turnaround was missing and received an imputed mean.
display(df_lags_imputed.filter(F.col('turnaround_time_calc').isNull()).filter(F.col('mean_turnaround_calc').isNotNull()))

In [0]:
# NOTE(review): this overwrites the cleaned checkpoint that the Lags section reads back,
# replacing it with df_lags - confirm that reusing the same path is intended.
df_lags.write.mode("overwrite").parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned.parquet/")

In [0]:
# Echo the DataFrame object (shows its repr only; no rows are computed here).
df_lags_imputed

In [0]:
# Alternative imputation via self-join: pair each flight with earlier flights on the same
# route and date that departed no later than its two-hours-prior cutoff, then average
# their turnaround times.
current = df_routes.alias("current")
prior = df_routes.alias("prior")

join_cond = (
    (F.col("current.ORIGIN") == F.col("prior.ORIGIN")) &
    (F.col("current.DEST") == F.col("prior.DEST")) &
    (F.col("current.FL_DATE") == F.col("prior.FL_DATE")) &
    (F.col("prior.actual_depart_utc") <= F.col("current.two_hours_prior_depart_UTC"))
)

# NOTE(review): verify that groupBy accepts the "current.*" star expansion on this Spark version.
result = (
    current.join(prior, join_cond, "left")
    .groupBy("current.*")
    .agg(F.avg("prior.turnaround_time_calc").alias("mean_turnaround_calc"))
    .withColumn(
        "turnaround_time_calc",
        F.when(
            ~valid_prior,  # Use your existing valid_prior condition
            F.col("mean_turnaround_calc")
        ).otherwise(F.col("turnaround_time_calc"))
    )
)

In [0]:
# A prior flight counts as valid when it landed at this origin, was scheduled within
# the 26-hour lookback and was not cancelled.
_landed_here = F.col("ORIGIN") == F.col("priorflight_dest")
_within_lookback = F.col("priorflight_sched_deptime") >= F.col("twentysix_hours_prior_depart_UTC")
_not_cancelled = F.col("priorflight_cancelled_true") == 0
valid_prior = _landed_here & _within_lookback & _not_cancelled


def hours(i):
    """Convert hours to seconds (the unit of the range frame below)."""
    return i * 3600


# Flights on the same route scheduled within the preceding 72 hours (current row included).
window_spec = (
    Window.partitionBy("ORIGIN", "DEST")
    .orderBy(F.col("sched_depart_utc").cast("long"))
    .rangeBetween(-hours(72), 0)
)

# Turnaround times only count toward the mean when their prior flight is valid.
_valid_turnaround = F.when(valid_prior, F.col("turnaround_time_calc")).otherwise(F.lit(None))
_needs_imputation = ~valid_prior | F.col('turnaround_time_calc').isNull()

df_optimized = (
    df_lags
    .withColumn("valid_turnaround", _valid_turnaround)
    .withColumn("mean_turnaround", F.avg("valid_turnaround").over(window_spec))
    # Rows with an invalid prior flight (or no turnaround) take the 72-hour route mean.
    .withColumn(
        "turnaround_time_calc",
        F.when(_needs_imputation, F.col("mean_turnaround"))
         .otherwise(F.col("turnaround_time_calc")),
    )
    .cache()
)

display(df_optimized)

In [0]:
# NOTE(review): this cell repeats the preceding cell verbatim; one of the two can likely be removed.
# A prior flight is valid when it landed at this origin, was scheduled within the
# 26-hour lookback and was not cancelled.
valid_prior = (
    (F.col("ORIGIN") == F.col("priorflight_dest")) &
    (F.col("priorflight_sched_deptime") >= F.col("twentysix_hours_prior_depart_UTC")) &
    (F.col("priorflight_cancelled_true") == 0)
)

# hours -> seconds, the unit of the range frame below
hours = lambda i: i * 3600

# window to include the prior 72 hours of flights on the same route (current row included)
window_spec = Window.partitionBy("ORIGIN", "DEST") \
                   .orderBy(F.col("sched_depart_utc").cast("long")) \
                   .rangeBetween(-hours(72),0)

# calculate conditional average
df_optimized = df_lags \
    .withColumn("valid_turnaround",
        F.when(
            valid_prior,
            F.col("turnaround_time_calc")
        ).otherwise(F.lit(None)) #only consider valid prior flights' turnaround times
    ) \
    .withColumn("mean_turnaround",
        F.avg("valid_turnaround").over(window_spec) #take mean over valid turnaround times
    ) \
    .withColumn("turnaround_time_calc",
        F.when(
            (~valid_prior | F.col('turnaround_time_calc').isNull()), 
               F.col("mean_turnaround")
               ) #replace turnarounds with an invalid prior (or a null value) by the route mean over the past 72 hours
          .otherwise(F.col("turnaround_time_calc"))
    ).cache()

display(df_optimized)

In [0]:
        # # Edge case handling - use Erica's code for: 
        # #1) revise a bit for turnaround time calc - avg over route prior flights
        # #2) use original for priorflight depdelay calc - lag mean delay (@prev origin)
        # #3) then update priorflight_isdelayed_calc
        # .withColumn("turnaround_time_calc",
        #     F.when(
        #         (~valid_prior),
        #         F.last("turnaround_time_calc", ignorenulls=True).over(route_window)
        #     ).otherwise(F.col("turnaround_time_calc"))
        # )
        # .withColumn("priorflight_depdelay_calc",
        #     F.when(
        #         (~valid_prior),
        #         F.last("priorflight_depdelay_calc", ignorenulls=True).over(route_window)
        #     ).otherwise(F.col("priorflight_depdelay_calc"))
        # )

        # .withColumn("priorflight_isdelayed_calc",
        #     F.when(
        #         (F.col("priorflight_depdelay_calc") >= 15) | 
        #         (F.col('priorflight_cancelled_true') == 1), 1
        #     ).otherwise(0)
        # )

# Graph

Brainstorming:

PageRank
- It makes sense to compute PageRank within time windows: the airports that are most popular in winter won't be the most popular in summer, and windowing is a reasonable assumption for scheduling while avoiding leakage.
- Personalized PageRank: the teleportation preference should favor the possible destinations with respect to each origin.
  - Convert to GraphX first.

In [0]:
from graphframes import GraphFrame


In [0]:
# Vertices: every airport code that appears as an origin or a destination.
_origin_ids = df_lags.select(F.col("ORIGIN").alias("id"))
_dest_ids = df_lags.select(F.col("DEST").alias("id"))
v = _origin_ids.union(_dest_ids).distinct()

# Edges: one directed edge per flight row (duplicates are kept).
e = df_lags.select(F.col("ORIGIN").alias("src"), F.col("DEST").alias("dst"))

g = GraphFrame(v, e)


In [0]:
# Define fold date ranges based on erica's cv
folds = [
    {"fold": 'train_0', "date_min": "2014-12-31", "date_max": "2015-10-09"},
    {"fold": 'test_0', "date_min": "2015-10-09", "date_max": "2016-07-17"},
    {"fold": 'train_1', "date_min": "2015-08-14", "date_max": "2016-05-21"},
    {"fold": 'test_1', "date_min": "2016-05-21", "date_max": "2017-02-27"},
    {"fold": 'train_2', "date_min": "2016-03-27", "date_max": "2017-01-01"},
    {"fold": 'test_2', "date_min": "2017-01-01", "date_max": "2017-10-10"},
    {"fold": 'train_3', "date_min": "2016-11-08", "date_max": "2017-08-14"},
    {"fold": 'test_3', "date_min": "2017-08-14", "date_max": "2018-05-23"},
    {"fold": 'train_4', "date_min": "2017-06-22", "date_max": "2018-03-27"},
    {"fold": 'test_4', "date_min": "2018-03-27", "date_max": "2018-01-01"}
    ]

In [0]:
def run_pagerank_on_folds(df, date_column='sched_depart_utc', folds=folds):
    """Run PageRank on the flight graph of each fold and stack the per-fold results.

    For every fold, the rows of ``df`` with ``date_min <= date_column < date_max``
    are selected, a directed graph with one ORIGIN -> DEST edge per row is built,
    and GraphFrames PageRank (resetProbability=0.15, maxIter=10) is run on it.

    Args:
        df: Spark DataFrame with ``ORIGIN``, ``DEST`` and ``date_column`` columns.
        date_column: Name of the column compared against the fold bounds.
        folds: List of dicts with the keys ``fold``, ``date_min`` and ``date_max``.

    Returns:
        The union of the PageRank vertex DataFrames of all folds, each tagged
        with a ``fold`` column holding the fold name.
    """
    results = []

    for fold in folds:
        print(f"Processing fold {fold['fold']}")

        # Keep only the flights that fall inside this fold's date range.
        fold_df = df.filter(
            (F.col(date_column) >= fold['date_min']) &
            (F.col(date_column) < fold['date_max'])
        )

        # Fix: destination vertices now come from the fold-filtered frame. They were
        # previously taken from the global `df_lags`, which ignored both the `df`
        # argument and the fold's date filter.
        vertices_df = (
            fold_df.select(F.col("ORIGIN").alias("id"))
            .union(fold_df.select(F.col("DEST").alias("id")))
            .distinct()
        )

        edges_df = fold_df.select(
            F.col("ORIGIN").alias("src"),
            F.col("DEST").alias("dst")
        )

        # Build the graph for this fold and rank its airports.
        fold_graph = GraphFrame(vertices_df, edges_df)
        pagerank_results = fold_graph.pageRank(resetProbability=.15, maxIter=10)

        # Tag every ranked vertex with the fold it was computed on.
        results.append(
            pagerank_results.vertices.withColumn("fold", F.lit(fold['fold']))
        )

    # Stack the per-fold vertex tables into a single DataFrame.
    all_results = results[0]
    for fold_result in results[1:]:
        all_results = all_results.union(fold_result)

    return all_results



In [0]:
pr_df= run_pagerank_on_folds(df_lags)

In [0]:
display(pr_df)

In [0]:
results = g.pageRank(resetProbability=0.15, maxIter = 10)


results.vertices.select("id", "pagerank").show()
results.edges.select("src", "dst", "weight").show()


In [0]:
display(results.edges.select('src','dst','weight').distinct().orderBy(F.col('weight').desc()))

In [0]:
display(results.vertices.select("id", "pagerank").orderBy(F.col('pagerank').desc()))


In [0]:
# paths = g.find("(a)-[e1]->(b); (b)-[e2]->(c); (c)-[e3]->(a)")  # need this to account for time / the immediate t+1 step, though

In [0]:
pr_df.write.mode("overwrite").parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/pagerank.parquet/")

## Combine

In [0]:
df = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned.parquet/")

In [0]:
pr = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/pagerank.parquet/")

In [0]:
# Fold date ranges, following Erica's cross-validation split.
# Each tuple is (fold name, date_min, date_max); it is expanded into the dict layout below.
folds = [
    dict(zip(("fold", "date_min", "date_max"), bounds))
    for bounds in (
        ('train_0', "2014-12-31", "2015-10-09"),
        ('test_0', "2015-10-09", "2016-07-17"),
        ('train_1', "2015-08-14", "2016-05-21"),
        ('test_1', "2016-05-21", "2017-02-27"),
        ('train_2', "2016-03-27", "2017-01-01"),
        ('test_2', "2017-01-01", "2017-10-10"),
        ('train_3', "2016-11-08", "2017-08-14"),
        ('test_3', "2017-08-14", "2018-05-23"),
        ('train_4', "2017-06-22", "2018-03-27"),
        ('test_4', "2018-03-27", "2019-01-01"),
    )
]

In [0]:
pr.filter(F.col('id')=='ATW').filter(F.col('fold')==0).show()

In [0]:
window_spec = Window.partitionBy('fold').orderBy('id')

pr_folds = pr.withColumn('row_num', F.row_number().over(window_spec))

In [0]:
# Number the rows inside each (fold, id) group of the PageRank table.
# NOTE(review): the window orders by 'id', which is constant inside a (fold, id)
# partition, so which row receives row_num 1 is not fixed by this ordering. Confirm
# this is acceptable, since row_num is mapped to train/test labels further down.
window_spec = Window.partitionBy('fold', 'id').orderBy('id')

pr_folds = pr.withColumn('row_num', F.row_number().over(window_spec))

# Collapse every row_num above 1 down to 2, so only the values 1 and 2 remain.
pr_folds = pr_folds.withColumn('row_num', F.when(F.col('row_num') > 1, 2).otherwise(F.col('row_num')))

In [0]:
df.filter(F.col('sched_depart_utc') > folds[0]['date_min']) \
  .filter(F.col('sched_depart_utc') <= folds[0]['date_max']) \
  .select('ORIGIN', 'DEST') \
  .select(F.explode(F.array('ORIGIN', 'DEST')).alias('airport')) \
  .distinct() \
  .count()

In [0]:
df.filter(F.col('sched_depart_utc') > folds[1]['date_min']) \
  .filter(F.col('sched_depart_utc') <= folds[1]['date_max']) \
  .select('ORIGIN', 'DEST') \
  .select(F.explode(F.array('ORIGIN', 'DEST')).alias('airport')) \
  .distinct() \
  .count()

In [0]:
# Turn the fold definitions into a small DataFrame so flights can be range-joined to them.
folds_df = spark.createDataFrame(folds)

# Attach a fold name to every flight whose sched_depart_utc lies in (date_min, date_max].
# The fold ranges overlap (e.g. train_1 starts before train_0 ends), so a flight inside
# two ranges appears once per matching fold. The inner join drops flights matching no fold.
df_with_folds = df.join(
    folds_df,
    (df['sched_depart_utc'] > folds_df['date_min']) & (df['sched_depart_utc'] <= folds_df['date_max']),
    'inner'
).select(df['*'], folds_df['fold'])

display(df_with_folds)

In [0]:
df_with_folds.filter(F.col('fold').isNull()).count()

In [0]:
# Derive a 'train_<k>' / 'test_<k>' label from the integer fold number and row_num:
# row_num 1 maps to train and row_num 2 to test. Rows matching no case get a null label.
fold_label_expr = None
for fold_number in range(5):
    for row_position, split_name in ((1, 'train'), (2, 'test')):
        is_match = (F.col('fold') == fold_number) & (F.col('row_num') == row_position)
        label_value = f'{split_name}_{fold_number}'
        if fold_label_expr is None:
            fold_label_expr = when(is_match, label_value)
        else:
            fold_label_expr = fold_label_expr.when(is_match, label_value)

pr_folds = pr_folds.withColumn('fold_label', fold_label_expr)

display(pr_folds)

In [0]:
df_with_pr = df_with_folds.join(pr_folds.select('id','pagerank','fold_label'), 
                                (df_with_folds['fold'] == pr_folds['fold_label']) & 
                                (df_with_folds['ORIGIN'] == pr_folds['id']), 
                                'left')
df_with_pr.filter(F.col('fold').isNull()).count()

In [0]:
df_with_pr.groupBy('YEAR').count().show()

In [0]:
df_with_pr.filter(F.col('TAIL_NUM').isNotNull()).filter(F.col('pagerank').isNull()).count()

In [0]:
df_with_pr.write.mode('overwrite').parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned_pr.parquet/")

In [0]:
display(df_with_pr.filter(F.col('TAIL_NUM').isNotNull()).filter(F.col('pagerank').isNull()))

In [0]:
display(df_with_pr.filter(F.col('TAIL_NUM').isNotNull()).filter(F.col('pagerank').isNull()))

## last folds

In [0]:
df_with_pr = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned_pr.parquet/")

In [0]:
df = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned.parquet/")

In [0]:
# Rebuild pr_folds from `pr`: number the rows per (fold, id), cap row_num at 2, then map
# (fold, row_num) to a 'train_<k>' / 'test_<k>' label.
# NOTE(review): ordering by 'id' inside a (fold, id) partition does not fix which row
# gets row_num 1, so the train/test assignment below depends on an unspecified row order.
window_spec = Window.partitionBy('fold', 'id').orderBy('id')

pr_folds = pr.withColumn('row_num', F.row_number().over(window_spec))

pr_folds = pr_folds.withColumn('row_num', F.when(F.col('row_num') > 1, 2).otherwise(F.col('row_num')))

# NOTE(review): these cases compare 'fold' with the integers 0-4, whereas
# run_pagerank_on_folds in this notebook writes string fold names such as 'train_0'.
# Presumably this targets an older pagerank.parquet layout; confirm which one `pr` holds.
pr_folds = pr_folds.withColumn(
    'fold_label',
    when((F.col('fold') == 0) & (F.col('row_num') == 1), 'train_0')
    .when((F.col('fold') == 0) & (F.col('row_num') == 2), 'test_0')
    .when((F.col('fold') == 1) & (F.col('row_num') == 1), 'train_1')
    .when((F.col('fold') == 1) & (F.col('row_num') == 2), 'test_1')
    .when((F.col('fold') == 2) & (F.col('row_num') == 1), 'train_2')
    .when((F.col('fold') == 2) & (F.col('row_num') == 2), 'test_2')
    .when((F.col('fold') == 3) & (F.col('row_num') == 1), 'train_3')
    .when((F.col('fold') == 3) & (F.col('row_num') == 2), 'test_3')
    .when((F.col('fold') == 4) & (F.col('row_num') == 1), 'train_4')
    .when((F.col('fold') == 4) & (F.col('row_num') == 2), 'test_4')
)

display(pr_folds)

In [0]:
pr_folds = pr_folds.filter(F.col('fold_label') != 'test_4')

In [0]:
display(pr_folds.filter(F.col('fold_label')=='train_4'))

In [0]:
pr = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/pagerank.parquet/")

In [0]:
def run_pagerank_on_folds(df, date_column='sched_depart_utc', folds=folds):
    """Compute PageRank per fold and return all per-fold vertex tables stacked together.

    Each fold dict supplies ``fold`` (its name) plus ``date_min`` / ``date_max``; rows
    of ``df`` with ``date_min <= date_column < date_max`` form that fold's graph, with
    airports as vertices and one ORIGIN -> DEST edge per row. The returned DataFrame
    is the union of every fold's PageRank vertices, with a ``fold`` column naming the
    fold each row belongs to.
    """

    def _rank_single_fold(spec):
        # Announce, filter to the fold's date range, build the graph, rank, and tag.
        print(f"Processing fold {spec['fold']}")

        in_range = (
            (F.col(date_column) >= spec['date_min']) &
            (F.col(date_column) < spec['date_max'])
        )
        flights = df.filter(in_range)

        origins = flights.select(F.col("ORIGIN").alias("id"))
        destinations = flights.select(F.col("DEST").alias("id"))
        airports = origins.union(destinations).distinct()

        routes = flights.select(
            F.col("ORIGIN").alias("src"),
            F.col("DEST").alias("dst"),
        )

        ranked = GraphFrame(airports, routes).pageRank(resetProbability=.15, maxIter=10)
        return ranked.vertices.withColumn("fold", F.lit(spec['fold']))

    per_fold = [_rank_single_fold(spec) for spec in folds]

    # Chain the per-fold tables into one DataFrame, in fold order.
    stacked = per_fold[0]
    for fold_frame in per_fold[1:]:
        stacked = stacked.union(fold_frame)

    return stacked



In [0]:
pr_final = run_pagerank_on_folds(df, folds=blocks)

In [0]:
display(pr_folds)

In [0]:
all_pr_folds = pr_folds.drop('row_num').drop('fold').withColumnRenamed('fold_label','fold').unionByName(pr_final)

In [0]:
display(all_pr_folds)

In [0]:
display(pr_df_full)

In [0]:
folds_df = spark.createDataFrame(blocks)

df_with_folds = df.join(
    folds_df,
    (df['sched_depart_utc'] > folds_df['date_min']) & (df['sched_depart_utc'] <= folds_df['date_max']),
    'inner'
).select(df['*'], folds_df['fold'])

display(df_with_folds)

In [0]:
df_with_folds

In [0]:
df_with_folds_final = df_with_folds.filter(F.col('fold').contains('test'))

In [0]:
pr_final

In [0]:


df_with_pr_final = df_with_folds_final.join(
    pr_final.select('id', 'pagerank', 'fold'),
    (df_with_folds_final['fold'] == pr_final['fold']) & 
    (df_with_folds_final['ORIGIN'] == pr_final['id']),
    'left'
)


In [0]:
# Widen `df` with one PageRank column per fold: for each fold label, the rows of that
# fold are selected, 'pagerank' is renamed to the label, and the result is joined on
# df.ORIGIN == id (the join key 'id' is dropped after every step).
# NOTE(review): no join type is passed, so every step is an inner join; presumably a
# flight whose ORIGIN is missing from any one fold table is dropped. Confirm that is intended.
# NOTE(review): `dftest3` and `dftrain4` are not defined in the visible part of this
# notebook, so this cell depends on state created elsewhere.
# NOTE(review): pr_folds had its 'test_4' rows filtered out earlier, so the 'test_4' join
# only finds rows if pr_final supplies them. Verify before relying on this column.
pr_df_full  = (df.join(all_pr_folds.filter(F.col('fold')=='train_0').withColumnRenamed('pagerank','train_0').drop('fold'), df['ORIGIN']==all_pr_folds['id']).drop('id')
        .join(all_pr_folds.filter(F.col('fold')=='test_0').withColumnRenamed('pagerank','test_0').drop('fold'), df['ORIGIN']==all_pr_folds['id']).drop('id')
        .join(all_pr_folds.filter(F.col('fold')=='train_1').withColumnRenamed('pagerank','train_1').drop('fold'), df['ORIGIN']==all_pr_folds['id']).drop('id')
        .join(all_pr_folds.filter(F.col('fold')=='test_1').withColumnRenamed('pagerank','test_1').drop('fold'), df['ORIGIN']==all_pr_folds['id']).drop('id')
        .join(all_pr_folds.filter(F.col('fold')=='train_2').withColumnRenamed('pagerank','train_2').drop('fold'), df['ORIGIN']==all_pr_folds['id']).drop('id')
        .join(all_pr_folds.filter(F.col('fold')=='test_2').withColumnRenamed('pagerank','test_2').drop('fold'), df['ORIGIN']==all_pr_folds['id']).drop('id')
        .join(all_pr_folds.filter(F.col('fold')=='train_3').withColumnRenamed('pagerank','train_3').drop('fold'), df['ORIGIN']==all_pr_folds['id']).drop('id')
        .join(dftest3.filter(F.col('fold') == 'test_3').withColumnRenamed('pagerank','test_3').drop('fold'), df['ORIGIN']==dftest3['id']).drop('id')
        .join(dftrain4.filter(F.col('fold')=='train_4').withColumnRenamed('pagerank','train_4').drop('fold'), df['ORIGIN']==dftrain4['id']).drop('id')
        .join(all_pr_folds.filter(F.col('fold')=='test_4').withColumnRenamed('pagerank','test_4').drop('fold'), df['ORIGIN']==all_pr_folds['id']).drop('id')
        .join(all_pr_folds.filter(F.col('fold')=='test').withColumnRenamed('pagerank','test').drop('fold'), df['ORIGIN']==all_pr_folds['id']).drop('id')

        
        )

In [0]:
df_with_pr_final.filter(F.col('pagerank').isNull()).count()

In [0]:
df_with_pr=df_with_pr.filter(F.col('fold')!='test_4')

In [0]:
df_with_pr_final.count()

In [0]:
df_with_pr = df_with_pr.dropDuplicates()

In [0]:
# Rename only the final column to 'fold_label', leaving every other column name untouched.
# (The join above appends pr_final's 'fold' last, after the flight-side 'fold'.)
original_cols = df_with_pr_final.columns
new_columns = [*original_cols[:-1], 'fold_label']
df_with_pr_final = df_with_pr_final.toDF(*new_columns)


In [0]:
df_pr = df_with_pr.unionByName(df_with_pr_final)

In [0]:
df_pr.groupBy('YEAR').count().show()

In [0]:
# Date span covered by the combined frame, computed in a single aggregation.
# Uses F.min / F.max instead of `from pyspark.sql.functions import min, max`, which
# shadowed Python's built-in min() and max() for every cell executed afterwards.
_date_bounds = df_pr.agg(
    F.min("FL_DATE").alias("min_date"),
    F.max("FL_DATE").alias("max_date"),
).first()
min_date, max_date = _date_bounds["min_date"], _date_bounds["max_date"]

min_date, max_date

In [0]:
pr_df_full.write.mode("overwrite").parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned_pr.parquet/")

In [0]:
pr_df_full

# Graph Features Exploration

## Community Detection

Idea: community detection. Not every airport is affected by every other airport, but an airport would be concerned about delays propagating within its own community.

In [0]:
from graphframes import GraphFrame


In [0]:
df = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned.parquet/")

In [0]:
ydf=df.filter(F.col('YEAR')==2015)

In [0]:
v = ydf.select(
    F.col("ORIGIN").alias("id")
).union(
    ydf.select(F.col("DEST").alias("id"))
).distinct()


e = ydf.select(
    F.col("ORIGIN").alias("src"),
    F.col("DEST").alias("dst")
)

g = GraphFrame(v, e)


### Label Propagation

TL;DR: as the Spark docs say, label propagation mostly just groups everything into a single label.

In [0]:
result = g.labelPropagation(maxIter=15)
result.groupBy('label').count().show()

In [0]:
e.count()

In [0]:
result.filter(F.col('label')=='1219770712064').show()

### Personalized PR



In [0]:
from pyspark.sql.functions import udf
from pyspark.sql.types import MapType, IntegerType, DoubleType


Main idea: compute the importance of nodes relative to a specific source node or set of source nodes. We can then derive metrics for the relevant nodes. Here, each node represents an airport.



In [0]:
ydf.groupBy("ORIGIN").agg(F.countDistinct("DEST").alias("dest_count")).orderBy(F.desc("dest_count")).show()

In [0]:
from pyspark.sql.window import Window

# Airport vertices: the distinct union of origins and destinations in ydf, sorted by id.
v = (
    ydf.selectExpr("ORIGIN AS id")
    .union(ydf.selectExpr("DEST AS id"))
    .distinct()
    .sort("id")
)

# One directed edge per flight row, from origin to destination.
e = ydf.selectExpr("ORIGIN AS src", "DEST AS dst")

# Directed flight graph for the ydf subset.
g = GraphFrame(v, e)


In [0]:

# Source airports in alphabetical order, each with a zero-based position in 'index'.
sources_df = (
    v.select("id")
    .distinct()
    .orderBy("id")
    .withColumn("index", F.row_number().over(Window.orderBy("id")) - 1)
)

# Plain Python list of the airport ids, in the same order as sources_df.
sources_flat = [row[0] for row in sources_df.select("id").collect()]


In [0]:

# Run PPR with aligned sources
pageranked = g.parallelPersonalizedPageRank(
    resetProbability=0.15, 
    sourceIds=sources_flat, 
    maxIter=10
)


In [0]:
display(pageranked.vertices.filter(F.col('id')=='JFK'))

In [0]:
# Example: rank all airports by their personalized PageRank with JFK as the source.
# Each vertex's `pageranks` vector holds that vertex's score for every source, indexed by
# the source's position in `sources_flat` (the list passed as sourceIds). JFK's scores are
# therefore entry `jfk_index` of EVERY vertex's vector, not one row of the table.
# Fix: the previous version took row number `jfk_index` of the collected table (row order
# is not tied to `sources_flat`) and zipped the ids against a single-element Row, so
# `top_external` always came out empty.
jfk_index = sources_flat.index("JFK")
jfk_scores = [
    (row["id"], float(row["pageranks"][jfk_index]))
    for row in pageranked.vertices.select("id", "pageranks").collect()
]

# Highest score first.
sorted_scores = sorted(
    jfk_scores,
    key=lambda x: x[1],
    reverse=True
)

# Exclude JFK itself by id, then keep the ten best remaining airports.
top_external = [entry for entry in sorted_scores if entry[0] != "JFK"][:10]
print(top_external)


In [0]:
sources_flat[51]

In [0]:


# Ship the ordered source-airport list to the executors once, for use inside the UDF.
broadcast_sources = sc.broadcast(sources_flat)

# UDF return type: an array of (origin, score) structs.
# Fix: ArrayType is referenced through the already-imported `types` module; the bare
# name `ArrayType` is not imported anywhere in the visible notebook.
result_schema = types.ArrayType(
    StructType([
        StructField("origin", StringType()),
        StructField("score", DoubleType())
    ])
)
def vector_to_dict(vector):
    """Return the 10 largest entries of a PPR vector as (source airport id, score) pairs.

    Args:
        vector: A vertex's `pageranks` vector; position i holds the score for the
            i-th entry of the broadcast source list.

    Returns:
        Up to 10 (source id, score) tuples, highest score first.
    """
    # Fix: read the broadcast copy instead of closing over the driver-side list.
    sources = broadcast_sources.value

    # Densify once, then rank every position by score and keep the top 10.
    sorted_entries = sorted(
        [(i, float(v)) for i, v in enumerate(vector.toArray())],
        key=lambda x: x[1],
        reverse=True
    )[:10]

    # Map vector positions back to airport ids.
    return [(sources[i], score) for i, score in sorted_entries]

# Define UDF
vector_to_dict_udf = udf(vector_to_dict, result_schema)




results = pageranked.vertices.withColumn("pagerank_dict", vector_to_dict_udf("pageranks"))


In [0]:
[u for u in enumerate(sources_flat) if u[1]=='HYA']

In [0]:
results.printSchema()

In [0]:
def vector_to_map(vector):
    """Convert a sparse PPR vector into a {source airport id: score} dict.

    Only the entries stored in the vector (its `indices` / `values`) are included;
    positions are translated to airport ids through the broadcast source list.
    """
    id_by_position = broadcast_sources.value
    score_by_source = {}
    for position, score in zip(vector.indices.tolist(), vector.values.tolist()):
        score_by_source[id_by_position[position]] = float(score)
    return score_by_source

vector_to_map_udf = udf(vector_to_map, MapType(StringType(), DoubleType()))
results = results.withColumn("pagerank_map", vector_to_map_udf(col("pageranks")))


pageranks (VectorType): the pageranks of this vertex from all input source vertices


In [0]:
r

In [0]:
display(pageranked.edges.limit(10))

In [0]:
display(results)

In [0]:
display(results.withColumn('pagerank', F.col('pageranks').cast('string')).limit(5))

In [0]:

# Broadcast the ordered source list used for the PPR run.
# Fix: this broadcasts `sources_flat`, the list passed as sourceIds, because the vector
# positions follow that order; `sources_list` is not defined in the visible notebook.
broadcast_sources = sc.broadcast(sources_flat)

# UDF return type: an array of (origin, score) structs.
# Fix: ArrayType is taken from the imported `types` module; the bare name is not imported.
result_schema = types.ArrayType(
    StructType([
        StructField("origin", StringType()),
        StructField("score", DoubleType())
    ])
)
def vector_to_dict(vector):
    """Return the 10 largest entries of a PPR vector as (source airport id, score) pairs.

    Args:
        vector: A vertex's `pageranks` vector; position i holds the score for the
            i-th entry of the broadcast source list.

    Returns:
        Up to 10 (source id, score) tuples, highest score first.
    """
    # Retrieve the broadcasted list
    sources = broadcast_sources.value

    # Densify once, then sort the entries by score and take the top 10
    sorted_entries = sorted(
        [(i, float(v)) for i, v in enumerate(vector.toArray())],
        key=lambda x: x[1],
        reverse=True
    )[:10]

    # Map indices to actual source IDs
    return [(sources[i], score) for i, score in sorted_entries]

# Define UDF
vector_to_dict_udf = udf(vector_to_dict, result_schema)

# Apply UDF
results = results.withColumn("pagerank_dict", vector_to_dict_udf("pageranks"))
display(results.limit(5)) #sanity check

In [0]:
display(results.filter(F.col('id')=='JFK'))

In [0]:
ydf.filter(F.col("ORIGIN")=="JFK").filter(F.col("DEST")=="HYA").count()

In [0]:
df.filter(F.col("ORIGIN")=="JFK").groupBy("DEST").count().orderBy(F.desc(F.col('count'))).show()

In [0]:
df.filter(F.col("ORIGIN")=="HYA").groupBy("DEST").count().orderBy(F.desc(F.col('count'))).show()

In [0]:
df.filter(F.col("ORIGIN")=="HYA").filter(F.col("DEST")=="JFK").count()

In [0]:
df.filter(F.col("ORIGIN")=="JFK").filter(F.col("DEST")=="HYA").count()

In [0]:
[i for i in enumerate(sources_list) if i[1]=='HYA']

In [0]:
# Show which source airports have a stored score for BGM.
# Fix: map_keys needs a map column; 'pagerank_dict' is an array of structs, so the
# keys are read from 'pagerank_map' instead.
display(results.filter(F.col('id')=='BGM').select(F.map_keys('pagerank_map').cast("string")))

In [0]:
display(results.limit(5))

In [0]:
results.select('id', 'pagerank_dict')

In [0]:
sources_list[200]

In [0]:
# Preview the top-10 source scores per vertex.
# Fix: `results` is a plain DataFrame in this section and has no `.vertices`; the vertex
# table lives on the GraphFrame returned by the PPR run (`pageranked`).
display(pageranked.vertices.withColumn("pagerank_dict", vector_to_dict_udf("pageranks")).limit(5))

# PPR: Folds

In [0]:
df = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned_pr_v2.parquet/")

In [0]:
# NOTE(review): this overwrites the very path `df` was read from in the previous cell.
# Overwriting a source that is still being lazily read is unsafe. Confirm the intended
# cell order, or write to a different path.
df.write.mode('overwrite').parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned_pr_v2.parquet/")

In [0]:
# Cast the prior flight's scheduled elapsed time to int and divide it by 60.
# Fix: the functions module is imported as `F`; the lowercase `f` was undefined.
df = df.withColumn('priorflight_sched_elapsed', F.col('priorflight_sched_elapsed').cast('int')/60)

In [0]:
# Draft of the per-fold value column for the first fold only: inside fold 0's date range
# take the 'train_0' column, otherwise null. The result is displayed, not assigned.
# Fixes: `withColumn` was miscased; each comparison is now parenthesized because `&` binds
# tighter than `>=` / `<=`; and the dangling `.` that made the cell a syntax error is removed.
df.withColumn(
    'train',
    F.when(
        (F.col('sched_depart_utc') >= folds[0]["date_min"]) &
        (F.col('sched_depart_utc') <= folds[0]["date_max"]),
        F.col('train_0')
    )
)

In [0]:
# Cross-validation fold date ranges (based on Erica's CV split), plus the held-out 'test' block.
# Bounds are ISO date strings; compared with a timestamp they act as midnight of that day.
# Fix: test_4 used to end at "2018-12-31" and test at "2019-12-31". An upper bound of
# midnight on the last day excludes that whole day, so flights on 2018-12-31 fell between
# test_4 and test. Both ranges now end at the start of the following day.
folds = [
    {"fold": 'train_0', "date_min": "2014-12-31", "date_max": "2015-10-09"},
    {"fold": 'test_0', "date_min": "2015-10-09", "date_max": "2016-07-17"},
    {"fold": 'train_1', "date_min": "2015-08-14", "date_max": "2016-05-21"},
    {"fold": 'test_1', "date_min": "2016-05-21", "date_max": "2017-02-27"},
    {"fold": 'train_2', "date_min": "2016-03-27", "date_max": "2017-01-01"},
    {"fold": 'test_2', "date_min": "2017-01-01", "date_max": "2017-10-10"},
    {"fold": 'train_3', "date_min": "2016-11-08", "date_max": "2017-08-14"},
    {"fold": 'test_3', "date_min": "2017-08-14", "date_max": "2018-05-23"},
    {"fold": 'train_4', "date_min": "2017-06-22", "date_max": "2018-03-27"},
    {"fold": 'test_4', "date_min": "2018-03-27", "date_max": "2019-01-01"},
    {"fold":"test", "date_min": "2019-01-01", "date_max": "2020-01-01"}
    ]

In [0]:
from pyspark.sql import functions as F

# Build one CASE expression that, for each flight, picks the value of the column named
# after the first fold whose (date_min, date_max] range contains sched_depart_utc.
# Folds 0-9 are checked in list order; the last fold (index 10) only has a lower bound.
_fold_value = None
for _spec in folds[:10]:
    _in_range = (
        (F.col('sched_depart_utc') > _spec["date_min"]) &
        (F.col('sched_depart_utc') <= _spec["date_max"])
    )
    if _fold_value is None:
        _fold_value = F.when(_in_range, F.col(_spec["fold"]))
    else:
        _fold_value = _fold_value.when(_in_range, F.col(_spec["fold"]))

_fold_value = _fold_value.when(
    (F.col('sched_depart_utc') >= folds[10]["date_min"]),
    F.col(folds[10]["fold"])
)

df = df.withColumn('train', _fold_value)

In [0]:
# NOTE(review): this only configures a DataFrameWriter; no sink method (e.g. .parquet(path))
# is called on it, so nothing is written by this cell.
df.write.mode('overwrite')

In [0]:
train0= df.filter(F.col("sched_depart_utc")>="2014-12-31").filter(F.col("sched_depart_utc")<="2015-10-09")

## Pregel


In [0]:
from graphframes.lib import Pregel
from graphframes import GraphFrame


In [0]:
edges = train0.select(
    F.col("priorflight_origin").alias("src"),
    F.col("ORIGIN").alias("dst"),
    F.col("priorflight_sched_deptime").alias("timestamp"),
    F.col('priorflight_depdelay_calc')
).orderBy("timestamp")

In [0]:
vertices = train0.select(F.col("ORIGIN").alias("id")) \
    .union(train0.select(F.col("DEST").alias("id"))) \
    .distinct() \
    .orderBy("id")


In [0]:
# Define windows (e.g., hourly)
window_duration = "1 hour"
lookback_duration = "2 hours"

# Add window start/end timestamps to edges
edges_with_windows = edges.withColumn(
    "window", 
    F.window(F.col("timestamp"), window_duration)
).select(
    "src", "dst", "timestamp", "priorflight_depdelay_calc",
    F.col("window.start").alias("window_start"),
    F.col("window.end").alias("window_end")
)

# Generate all possible windows (sorted)
windows = edges_with_windows.select("window_start", "window_end").distinct() \
    .orderBy("window_start") \
    .collect()

In [0]:
display(edges_with_windows)

In [0]:
g = GraphFrame(vertices, edges)


In [0]:
# PageRank-style propagation with the GraphFrames Pregel API.
numVertices = vertices.count()
alpha=.15
# Fix: the initial rank was lit("priorflight_depdelay_calc"), a constant STRING rather than
# a number. Every vertex now starts from the uniform rank 1 / numVertices.
# Each superstep: a vertex sends half of its rank along each outgoing edge, incoming
# messages are summed, and the new rank is (1 - alpha) * sum + alpha / numVertices.
result = g.pregel \
    .withVertexColumn("rank", lit(1.0 / numVertices), \
              F.coalesce(Pregel.msg(), lit(0.0)) * lit(1.0 - alpha) + lit(alpha / numVertices)) \
    .sendMsgToDst(Pregel.src("rank")*.5) \
    .aggMsgs(F.sum(Pregel.msg())) \
    .run()

In [0]:
lookback_duration

In [0]:
vertices = edges.selectExpr("src as id").union(edges.selectExpr("dst as id")).distinct() \
    .withColumn("state", F.lit(0.0)) #initialize state=0

In [0]:
# Carry a per-airport "state" across the first three time windows: for each window, build a
# graph from the edges in [window_start - lookback, window_end) and run Pregel on it.
# Relies on `windows`, `edges_with_windows`, `lookback_duration` and `numVertices` from earlier cells.
vertices = edges.selectExpr("src as id").union(edges.selectExpr("dst as id")).distinct() \
    .withColumn("state", F.lit(0.0)) #initialize state=0
# NOTE(review): the loop variable `window` rebinds the name `window` imported from
# pyspark.sql.functions at the top of the notebook.
for window in windows[:3]:
    current_window_end = window.window_end
    # Column expression: the collected window start minus the look-back interval.
    lookback_start = window.window_start - F.expr(f"INTERVAL {lookback_duration}")
    
    # Filter edges from [lookback_start, current_window_end]
    current_edges = edges_with_windows.filter(
        (F.col("timestamp") >= lookback_start) & 
        (F.col("timestamp") < current_window_end)
    )
    
    # Build graph for this window
    g = GraphFrame(vertices, current_edges)
    
    # Run Pregel with "new_state" column
    # NOTE(review): the update expression applies F.sum(...).over(Window) to the message
    # although aggMsgs already sums the messages per vertex. Confirm this is intended
    # and supported inside withVertexColumn.
    result = g.pregel \
        .withVertexColumn(
            "new_state",
            F.col("state"),  # Now safe to reference "state"
            F.least(
                F.lit(1.0), 
                F.coalesce(F.sum(Pregel.msg()).over(Window.partitionBy('id')), F.lit(0.0)) * 0.85 + F.lit(0.15 / numVertices)
            )
        ) \
        .sendMsgToDst(Pregel.src("new_state") * Pregel.edge("priorflight_depdelay_calc")) \
        .aggMsgs(F.sum(Pregel.msg())) \
        .run()
    
    # Rename "new_state" to "state" for next iteration
    vertices = result.drop("state").withColumnRenamed("new_state", "state")

In [0]:
display(current_edges)

In [0]:
# Initialize airports with delay state = 0.0
vertices = (
    train0.select(F.col("ORIGIN").alias("id"))
    .union(train0.select(F.col("DEST").alias("id")))
    .union(train0.select(F.col("priorflight_origin").alias("id")))
    .distinct()
    .withColumn("delay_state", F.lit(0.0))
)


# One edge per flight, ORIGIN -> DEST, carrying the flight's mean_dep_delay value.
# Fix: withColumn was called without a column name, which raises a TypeError. The clipped
# value (negative delays raised to 0) now replaces "mean_delay" in place.
edges = train0.select(
    F.col("ORIGIN").alias("src"),  # Departure airport
    F.col("DEST").alias("dst"),    # Arrival airport
    F.col("sched_depart_utc").alias("timestamp"),
    F.col("mean_dep_delay").alias("mean_delay")
).withColumn("mean_delay", F.greatest(F.col("mean_delay"), F.lit(0.0))).orderBy("timestamp")


# Assign every edge to its 1-hour window and record where a 2-hour look-back would start.
edges_with_windows = edges.withColumn(
    "window", 
    F.window("timestamp", "1 hour")
).withColumn(
    "lookback_start", 
    F.col("window.start") - F.expr("INTERVAL 2 HOURS")
)
# Distinct windows in chronological order, collected to the driver as Rows (start, end).
windows = edges_with_windows.select("window.start", "window.end").distinct().orderBy("window.start").collect()


In [0]:
edges_with_windows = edges.withColumn(
    "window", 
    F.window("timestamp", "1 hour")
).withColumn(
    "lookback_start", 
    F.col("window.start") - F.expr("INTERVAL 2 HOURS")
)

# Get sorted list of time windows
windows = edges_with_windows.select("window.start", "window.end").distinct().orderBy("window.start").collect()

In [0]:
train0.select("prop_delayed").summary().show()

In [0]:
aircraft_window = Window.partitionBy("TAIL_NUM").orderBy('sched_depart_utc')
train0 = train0.withColumn("priorflight_prop_delay",
                  F.lag(F.col("prop_delayed")).over(aircraft_window))

In [0]:
train0.filter(F.col("priorflight_prop_delay").isNull()).count()

In [0]:
train0.filter(F.col("priorflight_prop_delay").isNotNull()).count()

In [0]:
train0.select("priorflight_depdelay_calc").summary().show()

In [0]:
display(train0.filter(F.col('prop_delayed')==1))

In [0]:
# Initialize airports with delay state = 0.0
vertices = (
    train0.select(F.col("ORIGIN").alias("id"))
    .union(train0.select(F.col("DEST").alias("id")))
    .union(train0.select(F.col("priorflight_origin").alias("id")))
    .distinct()
    .withColumn("delay_state", F.lit(0.0))
)


edges = train0.select(
    F.col("priorflight_origin").alias("src"),  # Delay source
    F.col("ORIGIN").alias("dst"),              # Current origin receiving delay
    F.col("prop_delayed"),                     # Proportion of flights being "currently" delayed at origin
    F.col("sched_depart_utc").alias("timestamp"), 
    F.col("priorflight_depdelay_calc").alias("prior_delay_calc") # amount of delay being passed on
).withColumn("priorflight_depdelay_calc", #trim at 0
             F.greatest(F.col("prior_delay_calc"), F.lit(0.0))
).orderBy("timestamp")


edges_with_windows = edges.withColumn(
    "window", 
    F.window("timestamp", "1 hour")
).withColumn(
    "lookback_start", 
    F.col("window.start") - F.expr("INTERVAL 2 HOURS")
)
windows = edges_with_windows.select("window.start", "window.end").distinct().orderBy("window.start").collect()


In [0]:
# Propagate a per-airport "delay state" over the prior-flight graph, one hourly window
# at a time (windows 10-12 only, as a trial run).

# Initialize airports with delay state = 0.0
vertices = (
    train0.select(F.col("ORIGIN").alias("id"))
    .union(train0.select(F.col("DEST").alias("id")))
    .union(train0.select(F.col("priorflight_origin").alias("id")))
    .distinct()
    .withColumn("delay_state", F.lit(0.0))
)


edges = train0.select(
    F.col("priorflight_origin").alias("src"),  # Delay source
    F.col("ORIGIN").alias("dst"),              # Current origin receiving delay
    F.col("prop_delayed"),                     # Proportion of flights being "currently" delayed at origin
    F.col("sched_depart_utc").alias("timestamp"), 
    F.col("priorflight_depdelay_calc").alias("prior_delay_calc") # amount of delay being passed on
).withColumn("priorflight_depdelay_calc", #trim at 0
             F.greatest(F.col("prior_delay_calc"), F.lit(0.0))
).orderBy("timestamp")


edges_with_windows = edges.withColumn(
    "window", 
    F.window("timestamp", "1 hour")
).withColumn(
    "lookback_start", 
    F.col("window.start") - F.expr("INTERVAL 2 HOURS")
)
windows = edges_with_windows.select("window.start", "window.end").distinct().orderBy("window.start").collect()

# Pregel returns one row per input vertex, so the vertex count is constant across
# iterations; count once instead of re-running the growing lineage on every pass.
num_vertices = vertices.count()

# Fix: the loop variable is no longer called `window`, which rebound the `window`
# function imported from pyspark.sql.functions at the top of the notebook.
for current_window in windows[10:13]:
    print(f'processing window {current_window}')
    window_start, window_end = current_window["start"], current_window["end"]

    # Filter edges in [current_window - 2 hours, current_window]
    # Fix: compare against THIS window's bounds. The previous filter compared each row's
    # timestamp with that row's own lookback_start / window.end columns, which is true
    # for every row, so every iteration used all edges.
    current_edges = edges_with_windows.filter(
        (F.col("timestamp") >= F.lit(window_start) - F.expr("INTERVAL 2 HOURS")) &
        (F.col("timestamp") < F.lit(window_end))
    )

    # Group by (src, dst) to retain edge structure
    # Fix: alias the F.first aggregate; unaliased, the output column is named
    # "first(prop_delayed)" and the F.col("prop_delayed") lookup below fails.
    current_edges = current_edges.groupBy("src", "dst").agg(
        F.avg("priorflight_depdelay_calc").alias("avg_delay_edge"),
        F.first("prop_delayed").alias("prop_delayed")  # Keep prop_delayed
    )

    # Compute max delay and normalize
    max_delay = current_edges.agg(F.max("avg_delay_edge")).first()[0]
    if max_delay:
        delay_risk = F.col("prop_delayed") * (F.col("avg_delay_edge") / F.lit(max_delay))
    else:
        # No edges in the window, or no positive delay: avoid dividing by None / 0.
        delay_risk = F.lit(0.0)
    current_edges = current_edges.withColumn("delay_risk_edge", delay_risk)

    current_edges.select('delay_risk_edge').summary().show()

    # Build the graph for this window and run 5 Pregel supersteps: each vertex sends
    # state * delay_risk_edge to its destinations, messages are summed, and the new
    # state is 0.8 * sum + 0.2 / num_vertices.
    g = GraphFrame(vertices, current_edges)
    result = g.pregel \
        .withVertexColumn(
            "new_delay_state",
            F.col("delay_state"),
            F.coalesce(Pregel.msg(), F.lit(0.0)) * 0.8 + 0.2 / num_vertices
        ) \
        .sendMsgToDst(Pregel.src("new_delay_state") * Pregel.edge("delay_risk_edge")) \
        .aggMsgs(F.sum(Pregel.msg())) \
        .setMaxIter(5) \
        .run()

    # Carry the updated state into the next window.
    vertices = result.drop("delay_state").withColumnRenamed("new_delay_state", "delay_state")
    vertices.select("delay_state").summary().show()


In [0]:
result.cache()

In [0]:
display(result)

In [0]:
# Variant of the windowed delay-propagation loop without the delay-risk normalization step.
# Initialize airports with delay state = 0.0
vertices = (
    train0.select(F.col("ORIGIN").alias("id"))
    .union(train0.select(F.col("DEST").alias("id")))
    .union(train0.select(F.col("priorflight_origin").alias("id")))
    .distinct()
    .withColumn("delay_state", F.lit(0.0))
)


edges = train0.select(
    F.col("priorflight_origin").alias("src"),  # Delay source
    F.col("ORIGIN").alias("dst"),              # Current origin receiving delay
    F.col("prop_delayed"),                     # Proportion of flights being "currently" delayed at origin
    F.col("sched_depart_utc").alias("timestamp"), 
    F.col("priorflight_depdelay_calc").alias("prior_delay_calc") # amount of delay being passed on
).withColumn("priorflight_depdelay_calc", #trim at 0
             F.greatest(F.col("prior_delay_calc"), F.lit(0.0))
).orderBy("timestamp")


# Tag each edge with its 1-hour window and the start of a 2-hour look-back before it.
edges_with_windows = edges.withColumn(
    "window", 
    F.window("timestamp", "1 hour")
).withColumn(
    "lookback_start", 
    F.col("window.start") - F.expr("INTERVAL 2 HOURS")
)
# Distinct windows in chronological order, collected to the driver.
windows = edges_with_windows.select("window.start", "window.end").distinct().orderBy("window.start").collect()

# NOTE(review): `window` rebinds the `window` function imported from pyspark.sql.functions.
for window in windows[10:13]:
    print(f'processing window {window}')
    # Filter edges in [current_window - 2 hours, current_window]
    # NOTE(review): both bounds are columns of the same row (its own lookback_start and
    # window.end), not the loop variable `window`, which is only used in the print above.
    # As written the filter does not select a specific window.
    current_edges = edges_with_windows.filter(
        (F.col("timestamp") >= F.col("lookback_start")) & 
        (F.col("timestamp") < F.col("window.end"))
    )

    # Group by (src, dst) to retain edge structure
    # NOTE(review): F.first("prop_delayed") has no alias here.
    current_edges = current_edges.groupBy("src", "dst").agg(
        F.avg("priorflight_depdelay_calc").alias("avg_delay_edge"),
        F.first("prop_delayed")  # Keep prop_delayed
    )

    # NOTE(review): 'delay_risk_edge' is read here and in sendMsgToDst below, but this
    # cell never creates that column on current_edges.
    current_edges.select('delay_risk_edge').summary().show()

    # Build graph and run Pregel (remaining code unchanged)
    g = GraphFrame(vertices, current_edges)
    # Update rule: new state = 0.8 * (sum of incoming messages, 0 if none) + 0.2 / vertex count.
    # vertices.count() is re-evaluated on every loop iteration.
    result = g.pregel \
        .withVertexColumn(
            "new_delay_state",
            F.col("delay_state"),
            F.coalesce(Pregel.msg(), F.lit(0.0)) * 0.8 + 0.2 / vertices.count()
        ) \
        .sendMsgToDst(Pregel.src("new_delay_state") * Pregel.edge("delay_risk_edge")) \
        .aggMsgs(F.sum(Pregel.msg())) \
        .setMaxIter(5) \
        .run()
    
    # Carry the updated state into the next iteration and print its summary statistics.
    vertices = result.drop("delay_state").withColumnRenamed("new_delay_state", "delay_state")
    vertices.select("delay_state").summary().show()


In [0]:
# Pregel propagation weighted by the edge's prop_delayed, one pass per 1-hour window
# (each pass also looks back 2 hours before the window start).

# Initialize airports with delay state = 0.0
vertices = (
    train0.select(F.col("ORIGIN").alias("id"))
    .union(train0.select(F.col("DEST").alias("id")))
    .distinct()
    .withColumn("delay_state", F.lit(0.0))
)


edges = train0.select(
    F.col("priorflight_origin").alias("src"),  # Delay source
    F.col("ORIGIN").alias("dst"),              # Current origin receiving delay
    F.col("prop_delayed"),                     # Proportion of flights being "currently" delayed at origin
    F.col("sched_depart_utc").alias("timestamp"),
    F.col("priorflight_depdelay_calc").alias("prior_delay_calc")  # amount of delay being passed on
).orderBy("timestamp")


edges_with_windows = edges.withColumn(
    "window",
    F.window("timestamp", "1 hour")
).withColumn(
    "lookback_start",
    F.col("window.start") - F.expr("INTERVAL 2 HOURS")
)
windows = edges_with_windows.select("window.start", "window.end").distinct().orderBy("window.start").collect()

for window in windows[10:13]:
    print(f'processing window {window}')
    # Filter edges in [current_window - 2 hours, current_window].
    # Bounds are taken from the loop's `window`; comparing a row with its own
    # lookback_start / window.end columns is always true and would keep every edge.
    lookback_start = F.lit(window.start) - F.expr("INTERVAL 2 HOURS")
    current_edges = edges_with_windows.filter(
        (F.col("timestamp") >= lookback_start) &
        (F.col("timestamp") < F.lit(window.end))
    )

    # # Group by (src, dst) to retain edge structure
    # current_edges = current_edges.groupBy("src", "dst").agg(
    #     F.avg("priorflight_depdelay_calc").alias("avg_delay_edge"),
    #     F.first("prop_delayed")  # Keep prop_delayed
    # )

    # Build graph and run Pregel: each vertex takes 0.8 x the summed incoming messages,
    # where a message is the source's state times the edge's prop_delayed.
    g = GraphFrame(vertices, current_edges)
    result = g.pregel \
        .withVertexColumn(
            "new_delay_state",
            F.col("delay_state"),
            F.coalesce(Pregel.msg(), F.lit(0.0)) * 0.8
        ) \
        .sendMsgToDst(Pregel.src("new_delay_state") * Pregel.edge("prop_delayed")) \
        .aggMsgs(F.sum(Pregel.msg())) \
        .setMaxIter(3) \
        .run()

    # Carry the updated state into the next window.
    vertices = result.drop("delay_state").withColumnRenamed("new_delay_state", "delay_state")
    vertices.select("delay_state").summary().show()


In [0]:
# Bare expression: shows the DataFrame repr (column names/types) of the windowed edges.
edges_with_windows

In [0]:
# Mark the edges of the last processed window (left over from the loop above) for caching.
current_edges.cache()

In [0]:
# Summary statistics of the edge weight.
# NOTE(review): no visible cell creates a `delay_risk_edge` column on `current_edges`;
# confirm where it is defined, otherwise this select cannot resolve the column.
current_edges.select('delay_risk_edge').summary().show()

In [0]:
# Render the vertex output of the most recent Pregel run.
display(result)

In [0]:
# Same display as the previous cell (kept to compare output after re-running the loop).
display(result)

In [0]:
# Distribution of `prop_delayed` in the training fold, to judge its scale as an edge weight.
train0.select('prop_delayed').summary().show()

In [0]:
# NOTE(review): detached fragment of a Pregel builder chain. It begins with
# `.withVertexColumn(` and has no `g.pregel` receiver, so this cell cannot run on its own.
# It keeps two vertex columns (`_tmp_prop`, `_tmp_delay`), sends a (prop, delay) struct
# along each edge and averages both fields per destination.
.withVertexColumn(
            "_tmp_prop",
            F.col("_tmp_src_prop"),
            (F.col("_tmp_src_prop") * 0.9) + F.coalesce(Pregel.msg().getItem("avg_prop"), F.lit(0.0)) * 0.1
        ) \
        .withVertexColumn(
            "_tmp_delay",
            F.col("delay_load"),
            (F.col("_tmp_delay") * 0.2) + F.coalesce(Pregel.msg().getItem("avg_delay"), F.lit(0.0))
        ) \
        .sendMsgToDst(
            F.struct(
                Pregel.edge("edge_prop_delayed").alias("prop"),
                (
                    Pregel.src("_tmp_src_prop") *  # Use temporary name
                    Pregel.edge("edge_prop_delayed") * 
                    Pregel.edge("delay_load")
                ).alias("delay")
            )
        ) \
        .aggMsgs(
            F.struct(
                F.coalesce(F.avg(Pregel.msg().getItem("prop")), F.lit(0.0)).alias("avg_prop"),
                F.coalesce(F.avg(Pregel.msg().getItem("delay")), F.lit(0.0)).alias("avg_delay")
            )
        ) \
        .setMaxIter(1) \
        .run()

In [0]:
# Airports seen as ORIGIN or DEST become vertices; both state columns start at 0.0.
vertices = (
    train0.select(F.col("ORIGIN").alias("id"))
    .union(train0.select(F.col("DEST").alias("id")))
    .distinct()
    .select(
        "id",
        F.lit(0.0).alias("delay_load"),
        F.lit(0.0).alias("prop_delayed"),  # Vertex property
    )
)

# One edge per flight (prior-flight origin -> current origin), sorted by scheduled departure.
edges = (
    train0.select(
        F.col("priorflight_origin").alias("src"),
        F.col("ORIGIN").alias("dst"),
        F.col("prop_delayed").alias("edge_prop_delayed"),
        F.col("sched_depart_utc").alias("timestamp"),
        # negative prior-flight delays are clamped to 0
        F.greatest(F.col("priorflight_depdelay_calc"), F.lit(0.0)).alias("delay_load"),
    )
    .orderBy("timestamp")
)

# Tag every edge with its 1-hour tumbling window and a 2-hour lookback bound.
edges_with_windows = (
    edges
    .withColumn("window", F.window("timestamp", "1 hour"))
    .withColumn("lookback_start", F.col("window.start") - F.expr("INTERVAL 2 HOURS"))
)

# Distinct (start, end) pairs, collected to the driver in chronological order.
windows = (
    edges_with_windows
    .select("window.start", "window.end")
    .distinct()
    .orderBy("window.start")
    .collect()
)


In [0]:
# Bare expression: shows the DataFrame repr of the last window's edges (left over from an earlier loop).
current_edges

In [0]:
        # NOTE(review): detached fragment of a Pregel builder chain. It begins with
        # `.withVertexColumn(` and has no `g.pregel` receiver, so this cell cannot run on its own.
        # NOTE(review): sendMsgToDst reads Pregel.src("vertex_src_prop_delayed"), but the vertex
        # columns declared in this fragment are "src_prop_delayed" and "delay_load" — confirm the name.
        .withVertexColumn(
            "src_prop_delayed",
            F.col("prop_delayed"),
            (F.col("prop_delayed") * 0.9) + 
            F.coalesce(Pregel.msg().getItem("avg_prop"), F.lit(0.0)) * 0.1
        ) \
        .withVertexColumn(
            "delay_load",
            F.col("delay_load"),
            (F.col("delay_load") * 0.2) + 
            F.coalesce(Pregel.msg().getItem("sum_delay"), F.lit(0.0))
        ) \
        .sendMsgToDst(
            F.struct(
                Pregel.edge("edge_prop_delayed").alias("prop"),
                (
                    Pregel.src("vertex_src_prop_delayed") * 
                    Pregel.edge("edge_prop_delayed") * 
                    Pregel.edge("delay_load")
                ).alias("delay")
            )
        ) \
        .aggMsgs(
            F.struct(
                F.coalesce(F.avg(Pregel.msg().getItem("prop")), F.lit(0.0)).alias("avg_prop"),
                F.coalesce(F.sum(Pregel.msg().getItem("delay")), F.lit(0.0)).alias("sum_delay")
            )
        ) \
        .setMaxIter(3) \
        .run()

In [0]:
# List the training fold's column names (lookup aid for the select() calls in this section).
train0.columns

In [0]:
# Run Pregel per window with a (prop, delay) struct message.
# Relies on `vertices`, `edges_with_windows` and `windows` from the setup cell above.
for window in windows[10:13]:
    print(f"Processing window {window}")

    # 1. Filter edges in [window start - 2 hours, window end).
    #    Bounds come from the loop's `window`; comparing a row with its own
    #    lookback_start / window.end columns is always true and would keep every edge.
    lookback_start = F.lit(window.start) - F.expr("INTERVAL 2 HOURS")
    current_edges = edges_with_windows.filter(
        (F.col("timestamp") >= lookback_start) &
        (F.col("timestamp") < F.lit(window.end))
    )

    # 2. Precompute source's average edge_prop_delayed
    src_prop = current_edges.groupBy("src").agg(
        F.avg("edge_prop_delayed").alias("_tmp_src_prop")
    )

    # 3. Join with vertices and fill nulls; drop the join key so the vertex frame
    #    does not carry a stray `src` column into the graph.
    vertices_window = vertices.join(
        src_prop,
        vertices.id == src_prop.src,
        "left"
    ).drop("src").fillna(0.0, subset=["_tmp_src_prop"])

    # 4. Build graph
    g = GraphFrame(vertices_window, current_edges)

    # 5. Run Pregel with structured messages
    result = g.pregel \
        .withVertexColumn(
            "new_delay_state",
            F.col("delay_load"),
            (F.col("delay_load") * 0.2) +  # keep 20% of the vertex's delay_load
            F.coalesce(Pregel.msg().getItem("avg_delay"), F.lit(0.0))
        ) \
        .withVertexColumn(
            "_tmp_prop",
            F.col("_tmp_src_prop"),
            (F.col("_tmp_prop") * 0.9) +  # 10% decay
            F.coalesce(Pregel.msg().getItem("avg_prop"), F.lit(0.0)) * 0.1
        ) \
        .sendMsgToDst(
            F.struct(
                (Pregel.edge("edge_prop_delayed") * 0.85).alias("prop"),  # Scaled by decay
                (Pregel.edge("delay_load") * Pregel.src("_tmp_prop")).alias("delay")
            )
        ) \
        .aggMsgs(
            F.struct(
                F.avg(Pregel.msg().getItem("prop")).alias("avg_prop"),
                # the aggregate is an average, so the field is named accordingly
                F.avg(Pregel.msg().getItem("delay")).alias("avg_delay")
            )
        ) \
        .setMaxIter(3) \
        .run()

    # 6. Show results
    print("\nFinal vertex states:")
    result.orderBy(F.desc('new_delay_state')).show(10, truncate=False)

    # NOTE: `vertices` is not reassigned in this loop, so every window starts
    # from the same initial vertex state.


In [0]:
# Per-window Pregel with state carried from one window to the next.
# Relies on `edges_with_windows` and `windows` from the setup cell above.

# Initialize vertices with non-zero values to force state changes
vertices = (
    train0.select(F.col("ORIGIN").alias("id"))
    .union(train0.select(F.col("DEST").alias("id")))
    .distinct()
    .withColumn("delay_load", F.lit(0.001))  # Non-zero initial state
    .withColumn("prop_delayed", F.lit(0.001))
)

for window in windows[10:13]:
    print(f"Processing window {window}")

    # 1. Filter edges in [window start - 2 hours, window end).
    #    Bounds come from the loop's `window`; comparing a row with its own
    #    lookback_start / window.end columns is always true and would keep every edge.
    lookback_start = F.lit(window.start) - F.expr("INTERVAL 2 HOURS")
    current_edges = edges_with_windows.filter(
        (F.col("timestamp") >= lookback_start) &
        (F.col("timestamp") < F.lit(window.end))
    )

    # 2. Precompute source's average edge_prop_delayed
    src_prop = current_edges.groupBy("src").agg(
        F.avg("edge_prop_delayed").alias("_tmp_src_prop")
    )

    # 3. Join with vertices and fill nulls; drop the join key so the vertex frame
    #    does not carry a stray `src` column into the graph.
    vertices_window = vertices.join(
        src_prop,
        vertices.id == src_prop.src,
        "left"
    ).drop("src").fillna(0.0, subset=["_tmp_src_prop"])

    # 4. Build graph
    g = GraphFrame(vertices_window, current_edges)

    # 5. Run Pregel: messages carry (0.85 x edge prop, edge delay x source _tmp_prop)
    #    and are averaged per destination.
    result = g.pregel \
        .withVertexColumn(
            "new_delay_state",
            F.col("delay_load"),
            (F.col("delay_load") * 0.2) +
            F.coalesce(Pregel.msg().getItem("avg_delay"), F.lit(0.0))
        ) \
        .withVertexColumn(
            "_tmp_prop",
            F.col("_tmp_src_prop"),
            (F.col("_tmp_prop") * 0.85) +
            F.coalesce(Pregel.msg().getItem("avg_prop")*15, F.lit(0.0))
        ) \
        .sendMsgToDst(
            F.struct(
                (Pregel.edge("edge_prop_delayed") * 0.85).alias("prop"),
                (Pregel.edge("delay_load") * Pregel.src("_tmp_prop")).alias("delay")
            )
        ) \
        .aggMsgs(
            F.struct(
                F.avg(Pregel.msg().getItem("prop")).alias("avg_prop"),
                F.avg(Pregel.msg().getItem("delay")).alias("avg_delay")
            )
        ) \
        .setMaxIter(3) \
        .run()

    # 6. Carry the new state into the next window
    vertices = result.select(
        "id",
        F.col("_tmp_prop").alias("prop_delayed"),
        F.col("new_delay_state").alias("delay_load")
    ).cache()  # Cache to ensure updates persist

    # 7. Force materialization to prevent plan recomputation
    vertices.count()


    print("\nFinal vertex states:")
    vertices.orderBy(F.desc("delay_load")).show(10, truncate=False)



In [0]:
from pyspark.sql import functions as F
from graphframes import GraphFrame
from graphframes.lib import Pregel

# Per-window Pregel on prior-flight edges (prior origin -> prior destination),
# with a 2-hour lookback and state carried between windows.

# Initialize vertices with non-zero values to force state changes
vertices = (
    train0.select(F.col("ORIGIN").alias("id"))
    .union(train0.select(F.col("DEST").alias("id")))
    .distinct()
    .withColumn("delay_load", F.lit(0.001))  # Avoid zero initialization
    .withColumn("prop_delayed", F.lit(0.001))
)

# Define edges with raw delay minutes (no normalization)
edges = train0.select(
    F.col("priorflight_origin").alias("src"),
    F.col("priorflight_dest").alias("dst"),
    F.col("prop_delayed").alias("edge_prop_delayed"),
    F.col("sched_depart_utc").alias("timestamp"),
    F.col("priorflight_depdelay_calc").alias("delay_load")
)

# Tag each edge with its 1-hour tumbling window and a 2-hour lookback bound
edges_with_windows = edges.withColumn(
    "window",
    F.window("timestamp", "1 hour")
).withColumn(
    "lookback_start",
    F.col("window.start") - F.expr("INTERVAL 2 HOURS")
)

windows = edges_with_windows.select("window.start", "window.end") \
    .distinct().orderBy("window.start").collect()

for window in windows[1:13]:
    print(f"Processing window {window}")

    # 1. Filter edges in [window start - 2 hours, window end).
    #    Bounds come from the loop's `window`; comparing a row with its own
    #    lookback_start / window.end columns is always true and would keep every edge.
    lookback_start = F.lit(window.start) - F.expr("INTERVAL 2 HOURS")
    current_edges = edges_with_windows.filter(
        (F.col("timestamp") >= lookback_start) &
        (F.col("timestamp") < F.lit(window.end))
    )

    # 2. Build graph using previous state
    g = GraphFrame(vertices, current_edges)
    g.edges.show()

    # 3. Run Pregel (single superstep)
    result = g.pregel \
        .withVertexColumn(
            "new_delay_state",
            F.col("delay_load"),
            (F.col("delay_load")*.1 +
                F.coalesce(Pregel.msg().getItem("avg_delay"),
                           F.lit(0.0))
                )  # 10% of the carried delay_load plus the averaged incoming delay
        ) \
        .withVertexColumn(
            "new_prop_delayed",
            F.col("prop_delayed"),
            (F.col("prop_delayed") * 0.1) +
            F.coalesce(Pregel.msg().getItem("avg_prop"), F.lit(0.0)) * 0.1  # both terms weighted 0.1
        ) \
        .sendMsgToDst(
            F.struct(
                Pregel.edge("edge_prop_delayed").alias("prop"),  # For prop_delayed updates
                (Pregel.edge("delay_load") * Pregel.src("new_prop_delayed")).alias("delay")  # Delay scaled by src's prop_delayed
            )
        ) \
        .aggMsgs(
            F.struct(
                F.avg(Pregel.msg().getItem("prop")*.15).alias("avg_prop"),  # Average of 0.15 x edge_prop_delayed
                F.avg(Pregel.msg().getItem("delay")*.15).alias("avg_delay")  # Average of 0.15 x scaled delay
            )
        ) \
        .setMaxIter(1) \
        .run()

    # 4. Update vertices for next window
    vertices = result.select(
        "id",
        F.col("new_prop_delayed").alias("prop_delayed"),
        F.col("new_delay_state").alias("delay_load")
    ).cache()

    # 5. Force materialization to persist state
    vertices.count()

    # 6. Show results
    print("\nFinal vertex states:")
    vertices.orderBy(F.desc("delay_load")).show(10, truncate=False)


In [0]:
from pyspark.sql import functions as F
from graphframes import GraphFrame
from graphframes.lib import Pregel

# Per-window Pregel on prior-flight edges; the prop message flags sources whose
# new_delay_state exceeds 15.

# Initialize vertices with non-zero values to force state changes
vertices = (
    train0.select(F.col("ORIGIN").alias("id"))
    .union(train0.select(F.col("DEST").alias("id")))
    .distinct()
    .withColumn("delay_load", F.lit(0.001))  # Avoid zero initialization
    .withColumn("prop_delayed", F.lit(0.001))
)

# Define edges with raw delay minutes (no normalization)
edges = train0.select(
    F.col("priorflight_origin").alias("src"),
    F.col("priorflight_dest").alias("dst"),
    F.col("prop_delayed").alias("edge_prop_delayed"),
    F.col("sched_depart_utc").alias("timestamp"),
    F.col("priorflight_depdelay_calc").alias("delay_load")
)

# Tag each edge with its 1-hour tumbling window
edges_with_windows = edges.withColumn(
    "window",
    F.window("timestamp", "1 hour")
)

windows = edges_with_windows.select("window.start", "window.end") \
    .distinct().orderBy("window.start").collect()

for window in windows[1:13]:
    print(f"Processing window {window}")

    # 1. Filter edges in the current window.
    #    Bounds come from the loop's `window`; comparing a row's timestamp with its own
    #    window.start / window.end columns is always true and would keep every edge.
    current_edges = edges_with_windows.filter(
        (F.col("timestamp") >= F.lit(window.start)) &
        (F.col("timestamp") < F.lit(window.end))
    )

    # 2. Build graph using previous state
    g = GraphFrame(vertices, current_edges)

    # 3. Run Pregel (single superstep)
    result = g.pregel \
        .withVertexColumn(
            "new_delay_state",
            F.col("delay_load"),
            (F.col("delay_load")*.1 +
                F.coalesce(Pregel.msg().getItem("avg_delay"),
                           F.lit(0.0))
                )  # 10% of the carried delay_load plus the averaged incoming delay
        ) \
        .withVertexColumn(
            "new_prop_delayed",
            F.col("prop_delayed"),
            (F.col("prop_delayed") * 0.1) +
            F.coalesce(Pregel.msg().getItem("avg_prop"), F.lit(0.0)) * 0.1  # both terms weighted 0.1
        ) \
        .sendMsgToDst(
            F.struct(
                (F.when(Pregel.src("new_delay_state")>15, 1).otherwise(0)).alias("prop"),  # 1 if the source's delay state exceeds 15
                (Pregel.edge("delay_load") * .8).alias("delay")  # 80% of the edge's delay
            )
        ) \
        .aggMsgs(
            F.struct(
                F.avg(Pregel.msg().getItem("prop")).alias("avg_prop"),  # Share of incoming edges from delayed sources
                F.avg(Pregel.msg().getItem("delay")).alias("avg_delay")  # Average scaled delays
            )
        ) \
        .setMaxIter(1) \
        .run()

    # 4. Update vertices for next window
    vertices = result.select(
        "id",
        F.col("new_prop_delayed").alias("prop_delayed"),
        F.col("new_delay_state").alias("delay_load")
    ).cache()

    # 5. Force materialization to persist state
    vertices.count()

    # 6. Show results
    print("\nFinal vertex states:")
    vertices.orderBy(F.desc("delay_load")).show(10, truncate=False)


In [0]:
# Render the edge DataFrame defined in the preceding cell.
display(edges)

In [0]:


# Debug-sized run: Pregel over 4-hour windows on the first 100 edges only.

# Vertices are the airports of the prior flight leg; delay_load starts at 0.0, prop_delayed at 0.2
vertices = (
    train0.select(F.col("priorflight_origin").alias("id"))
    .union(train0.select(F.col("priorflight_dest").alias("id")))
    .distinct()
    .withColumn("delay_load", F.lit(0.0))
    .withColumn("prop_delayed", F.lit(0.2))
)

# Define edges with raw delay minutes (no normalization)
edges = train0.select(
    F.col("priorflight_origin").alias("src"),
    F.col("priorflight_dest").alias("dst"),
    F.col("prop_delayed").alias("edge_prop_delayed"),
    F.col("sched_depart_utc").alias("timestamp"),
    F.col("priorflight_depdelay_calc").alias("delay_load")
)


# Tag each edge with its 4-hour tumbling window.
# Sort by "timestamp": `sched_depart_utc` was aliased to it in the select above,
# so the original name is no longer a column of this frame.
edges_with_windows = edges.withColumn(
    "window",
    F.window("timestamp", "4 hours")
).orderBy("timestamp")


# Keep only the 100 earliest edges so the loop runs quickly
edges_with_windows = edges_with_windows.limit(100)
edges_with_windows.cache()



windows = edges_with_windows.select("window.start", "window.end") \
    .distinct().orderBy('window.start').collect()

for window in windows[0:13]:
    print(f"Processing window {window}")

    # 1. Filter edges in the current window (at most 50 of them).
    #    Bounds come from the loop's `window`; comparing a row's timestamp with its own
    #    window.start / window.end columns is always true and would keep every edge.
    current_edges = edges_with_windows.filter(
        (F.col("timestamp") >= F.lit(window.start)) &
        (F.col("timestamp") < F.lit(window.end))
    ).limit(50)


    # 2. Build graph using previous state
    g = GraphFrame(vertices, current_edges)


    # 3. Run Pregel
    result = g.pregel \
        .withVertexColumn(
            "new_delay_state",
            F.col("delay_load"),
            F.col("delay_load")*.1 +
            F.coalesce(Pregel.msg().getItem("avg_delay")*.8,
                           F.lit(0.0))  # 10% of carried delay_load + 80% of the averaged incoming delay
        ) \
        .withVertexColumn(
            "new_prop_delayed",
            F.col("prop_delayed"),
            F.coalesce(Pregel.msg().getItem("avg_prop"), F.lit(0.0))
        ) \
        .sendMsgToDst(
            F.struct(
                Pregel.src("new_prop_delayed").alias("prop"),  # sent but not read by aggMsgs below
                (Pregel.edge("delay_load") * .5).alias("delay")  # half of the edge's delay
            )
        ) \
        .aggMsgs(
            F.struct(
                F.avg(F.when(Pregel.msg().getItem("delay") > 15, 1).otherwise(0)).alias("avg_prop"),  # share of messages with halved delay > 15
                F.avg(Pregel.msg().getItem("delay")).alias("avg_delay")  # Average scaled delays
            )
        ) \
        .setMaxIter(3) \
        .run()

    print(f'result:')
    result.orderBy(F.desc('new_delay_state')).show()
    # 4. Update vertices for next window
    vertices = result.select(
        "id",
        F.col("new_prop_delayed").alias("prop_delayed"),
        F.col("new_delay_state").alias("delay_load")
    ).cache()

    # 5. Force materialization to persist state
    vertices.count()




# WORKING VERSION

In [0]:
# Date ranges of the cross-validation folds (train_k / test_k pairs) plus the final
# held-out "test" range; every entry has the keys fold, date_min and date_max.
folds = [
    dict(zip(("fold", "date_min", "date_max"), bounds))
    for bounds in (
        ("train_0", "2014-12-31", "2015-10-09"),
        ("test_0", "2015-10-09", "2016-07-17"),
        ("train_1", "2015-08-14", "2016-05-21"),
        ("test_1", "2016-05-21", "2017-02-27"),
        ("train_2", "2016-03-27", "2017-01-01"),
        ("test_2", "2017-01-01", "2017-10-10"),
        ("train_3", "2016-11-08", "2017-08-14"),
        ("test_3", "2017-08-14", "2018-05-23"),
        ("train_4", "2017-06-22", "2018-03-27"),
        ("test_4", "2018-03-27", "2018-12-31"),
        ("test", "2019-01-01", "2019-12-31"),
    )
]

In [0]:
# Flight legs as (src, dst) pairs together with their actual arrival time.
train0_msg = (
    train0
    .select("ORIGIN", "DEST", "actual_arr_utc")
    .withColumnRenamed("ORIGIN", "src")
    .withColumnRenamed("DEST", "dst")
)

In [0]:
# Per-window Pregel over the whole fold: one edge per flight (ORIGIN -> DEST),
# grouped into 4-hour windows by actual arrival time, with vertex state carried
# from each window into the next.

# Airports seen as ORIGIN or DEST; delay_load starts at 0.0 and prop_delayed at 0.2
vertices = (
    train0.select(F.col("ORIGIN").alias("id"))
    .union(train0.select(F.col("DEST").alias("id")))
    .distinct()
    .withColumn("delay_load", F.lit(0.0))
    .withColumn("prop_delayed", F.lit(0.2))
)

# Define edges with raw delay minutes (no normalization)
edges = train0.select(
    F.col("ORIGIN").alias("src"),
    F.col("DEST").alias("dst"),
    F.col("prop_delayed").alias("edge_prop_delayed"),
    F.col("actual_arr_utc").alias("timestamp"), 
    F.col("DEP_DELAY").alias("delay_load")
)


# Tag each edge with its 4-hour tumbling window
edges_with_windows = edges.withColumn(
    "window", 
    F.window("timestamp", "4 hours")
).orderBy("timestamp")

# Cache and materialize once; the frame is filtered again for every window below
edges_with_windows.cache()
edges_with_windows.count()
windows = edges_with_windows.select("window.start", "window.end") \
    .distinct().orderBy('window.start').collect()

# NOTE: later cells read `window` after this loop, so the loop variable name matters.
for window in windows:
    print(f"Processing window {window}")
    
    # 1. Filter edges in the current window; bounds are literals taken from the collected Row
    current_edges = edges_with_windows.filter(
        (F.col("timestamp") >= F.lit(window.start)) & 
        (F.col("timestamp") < F.lit(window.end))
        )

    # 2. Build graph using previous state
    g = GraphFrame(vertices, current_edges)

    # 3. Run Pregel
    result = g.pregel \
        .withVertexColumn(
            "new_delay_state",
            F.col("delay_load"),
            F.col("delay_load")*.1 + 
            F.coalesce(Pregel.msg().getItem("avg_delay")*.8, 
                           F.lit(0.0)) # 10% of carried delay_load + 80% of the averaged incoming delay
        ) \
        .withVertexColumn(
            "new_prop_delayed",
            F.col("prop_delayed"),
            F.coalesce(Pregel.msg().getItem("avg_prop"), F.lit(0.0))
        ) \
        .sendMsgToDst(
            F.struct(
                Pregel.src("new_prop_delayed").alias("prop"),  # sent, but aggMsgs below does not read this field
                (Pregel.edge("delay_load") * .5).alias("delay")  # half of the edge's delay_load
            )
        ) \
        .aggMsgs(
            F.struct(
                F.avg(F.when(Pregel.msg().getItem("delay") > 15, 1).otherwise(0)).alias("avg_prop"),  # share of messages whose halved delay exceeds 15
                F.avg(Pregel.msg().getItem("delay")).alias("avg_delay")  # Average scaled delays
            )
        ) \
        .setMaxIter(3) \
        .run()

    # 4. Update vertices for next window
    # NOTE(review): the per-window result is only kept in `vertices`; nothing is written out here.
    vertices = result.select(
        "id",
        F.col("new_prop_delayed").alias("prop_delayed"),
        F.col("new_delay_state").alias("delay_load")
    ).cache()
    
    # 5. Force materialization to persist state
    vertices.count()


In [0]:
# Inspect the literal Column built from the start of the last window left over from the loop above.
F.lit(window.start)

In [0]:
# Inspect the literal Column built from the end of the last window left over from the loop above.
F.lit(window.end)

In [0]:
# Attach the last window's bounds as literal columns to 10 rows to see how they render next to the data.
display(train0.withColumn("window_end", F.lit(window.end)).withColumn("window_start",F.lit(window.start)).limit(10))

# RDD WORKING VERSION

In [0]:
# Team storage root; Spark checkpoints are written to its interim/ folder.
team_BASE_DIR = "dbfs:/student-groups/Group_4_1"
spark.sparkContext.setCheckpointDir(team_BASE_DIR + "/interim")

In [0]:
from datetime import datetime, timedelta
from pyspark.mllib.linalg import Vectors


In [0]:
%scala
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD
import java.sql.Timestamp

// One flight leg: origin/destination airport codes, departure delay and the timestamp used for windowing.
case class Flight(origin: String, dest: String, depDelay: Double, timestamp: Timestamp)

// Load the cleaned join checkpoint, rename columns to match Flight's fields,
// drop rows without origin/dest and replace null DEP_DELAY with 0.
// `.limit(100)` restricts this to a 100-row debug sample.
val flights: RDD[Flight] = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned_pr_v2.parquet/")
  .withColumnRenamed("actual_arr_utc", "timestamp")
  .withColumnRenamed("ORIGIN", "origin")
  .withColumnRenamed("DEST", "dest")
  .na.drop(Seq("origin", "dest"))
  .na.fill(Map("DEP_DELAY" -> 0))
  .withColumnRenamed("DEP_DELAY", "depDelay")
  .limit(100)
  .as[Flight]
  .rdd
// Pull 10 rows to the driver as a quick sanity check
flights.take(10)

// Assign 4-hour windows: floor the epoch-millisecond timestamp to a multiple of 4 hours.
// NOTE(review): rows with a null `timestamp` are not filtered out here — confirm none occur.
// NOTE(review): vertex ids come from String.hashCode, which is not guaranteed unique per code.
val windowedEdges = flights.map { f =>
  val windowStart = f.timestamp.getTime - (f.timestamp.getTime % (4 * 60 * 60 * 1000))
  (windowStart, Edge(f.origin.hashCode.toLong, f.dest.hashCode.toLong, f.depDelay))
}

// All edges of a window grouped under that window's start (epoch millis)
val edgesByWindow = windowedEdges.groupByKey()

// One vertex per distinct airport code, keyed by the same hashCode-based id
val vertices = flights.flatMap(f => Seq(f.origin, f.dest))
  .distinct()
  .map { airportId =>
    (airportId.hashCode.toLong, (0.0, 0.2)) // (id, (delayLoad, propDelayed))
  }

In [0]:
%scala
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD
import java.sql.Timestamp


// One flight leg: origin/destination airport codes, departure delay and the timestamp used for windowing.
case class Flight(origin: String, dest: String, depDelay: Double, timestamp: Timestamp)

// Load the cleaned join checkpoint, rename columns to match Flight's fields,
// drop rows without origin/dest and replace null DEP_DELAY with 0.
// `.limit(100)` restricts this to a 100-row debug sample.
val flights: RDD[Flight] = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned_pr_v2.parquet/")
  .withColumnRenamed("actual_arr_utc", "timestamp")
  .withColumnRenamed("ORIGIN", "origin")
  .withColumnRenamed("DEST", "dest")
  .na.drop(Seq("origin", "dest"))
  .na.fill(Map("DEP_DELAY" -> 0))
  .withColumnRenamed("DEP_DELAY", "depDelay")
  .limit(100)
  .as[Flight]
  .rdd



## start here

In [0]:
df = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned_pr_v2.parquet/")

In [0]:
%scala
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD
import java.sql.Timestamp
import org.apache.spark.sql.functions.{coalesce, col}


In [0]:
%scala
// One flight leg: origin/destination airport codes, departure delay and the timestamp used for windowing.
case class Flight(origin: String, dest: String, depDelay: Double, timestamp: Timestamp)


In [0]:
%scala

// 1. Load data with coalesced timestamps:
//    use actual_arr_utc when present, otherwise sched_arr_utc; replace null DEP_DELAY with 0.
val flights: RDD[Flight] = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned_pr_v2.parquet/")
  .na.drop(Seq("ORIGIN", "DEST")) // Filter rows with null origins/destinations first
  .withColumn("timestamp", coalesce(col("actual_arr_utc"), col("sched_arr_utc")).cast("timestamp"))
  .filter(col("timestamp").isNotNull) // Remove rows with both timestamps null
  .withColumnRenamed("ORIGIN", "origin")
  .withColumnRenamed("DEST", "dest")
  .na.fill(Map("DEP_DELAY" -> 0))
  .withColumnRenamed("DEP_DELAY", "depDelay")
  .as[Flight]
  .rdd

// 2. Verify flight data (count() runs a full pass over the data)
println(s"Flight count after coalesce: ${flights.count()}")
// flights.limit(5).foreach { f =>
//   println(s"Flight: ${f.origin}->${f.dest} | Timestamp: ${f.timestamp} | Delay: ${f.depDelay}")
// }

// 3. Rebuild airport ID map from FULL dataset: distinct codes are collected to the driver
val airportCodes = flights.flatMap(f => Seq(f.origin, f.dest)).distinct().collect()


// Each code's id is its position in airportCodes
val airportIdMap = airportCodes.zipWithIndex.map { case (code, idx) => 
  (code, idx.toLong) 
}.toMap
println(s"Airport ID map size: ${airportIdMap.size}")

// val bcAirportIds = spark.sparkContext.broadcast(airportIdMap)

// 4. Debug edge creation with coalesced timestamps:
//    key each edge by its 4-hour window start (epoch millis floored to a multiple of 4 hours);
//    airportIdMap is captured by the closure rather than broadcast.
val windowedEdges = flights.map { f =>
  val srcId = airportIdMap(f.origin)
  val dstId = airportIdMap(f.dest)
  val windowStart = f.timestamp.getTime - (f.timestamp.getTime % (4 * 60 * 60 * 1000))
  (windowStart, Edge(srcId, dstId, f.depDelay))
}

// 5. Print edges with type annotations; ids are mapped back to codes by scanning airportIdMap
println("First 10 windowed edges:")
windowedEdges.take(10).foreach { 
  case (window: Long, edge: Edge[Double]) =>
    val windowTime = new Timestamp(window).toString
    val origin = airportIdMap.find(_._2 == edge.srcId).map(_._1).getOrElse("UNKNOWN")
    val dest = airportIdMap.find(_._2 == edge.dstId).map(_._1).getOrElse("UNKNOWN")
    println(s"$windowTime: ${origin}(${edge.srcId}) -> ${dest}(${edge.dstId}) | Delay: ${edge.attr} mins")
}



In [0]:
%scala
// Edges bucketed by the start (epoch millis) of their 4-hour window.
val edgesByWindow = windowedEdges.groupByKey()

// Every airport starts with state (delayLoad = 0.0, propDelayed = 0.2).
val vertices = airportCodes.map { airport =>
  airportIdMap(airport) -> (0.0, 0.2)
}.toSeq
val verticesRDD = spark.sparkContext.parallelize(vertices)

// Vertex state carried from one window to the next, and the windows in chronological order.
var currentVertices: RDD[(VertexId, (Double, Double))] = verticesRDD
val windowOrder = edgesByWindow.map(_._1).collect().sorted



In [0]:
%scala
// Message aggregated per destination vertex: running sums plus the number of
// contributing edges, so averages can be taken in vProg.
case class Message(sumProp: Double, sumDelay: Double, count: Long)

// Row type for saving per-window vertex state
case class VertexResult(
    id: Long, 
    delay_load: Double, 
    prop_delayed: Double, 
    window_start: Timestamp
)


// Vertex program.
// state is (delayLoad, propDelayed). A message with count == 0 leaves the state
// unchanged; otherwise delayLoad becomes 0.1 * old delayLoad + 0.8 * average
// incoming delay, and propDelayed becomes the average incoming prop.
// Per-call println debugging was removed: this runs for every vertex on every superstep.
def vProg(vertexId: VertexId, state: (Double, Double), msg: Message): (Double, Double) = {
    if (msg.count == 0) state else {
        val avgProp = msg.sumProp / msg.count
        val avgDelay = msg.sumDelay / msg.count
        val newDelay = state._1 * 0.1 + avgDelay * 0.8
        val newProp = avgProp
        (newDelay, newProp)
    }
}

// Message sending function.
// Each edge sends its destination the source's propDelayed and half of the edge's delay.
// Per-call println debugging was removed: this runs for every edge on every superstep.
def sendMsg(triplet: EdgeTriplet[(Double, Double), Double]): Iterator[(VertexId, Message)] = {
    val srcProp = triplet.srcAttr._2
    val scaledDelay = triplet.attr * 0.5
    Iterator((triplet.dstId, Message(srcProp, scaledDelay, 1L)))
}

// Message merging function: field-wise sum of two messages
def mergeMsg(a: Message, b: Message): Message = {
    Message(
        a.sumProp + b.sumProp,
        a.sumDelay + b.sumDelay, 
        a.count + b.count
    )
}

In [0]:
%scala


// Process the windows in chronological order: build a graph from the carried vertex
// state and the window's edges, run Pregel, then append the resulting state to parquet.
windowOrder.foreach { windowStart =>

  println(s"Window starting ${windowStart}")
  // Fetch this window's edges to the driver (they are re-parallelized below)
  val edges = edgesByWindow.lookup(windowStart).flatMap(_.iterator)
  
  val graph = Graph(currentVertices, spark.sparkContext.parallelize(edges))
  
  // The initial message has count 0, which vProg treats as "keep the current state"
  val updatedGraph = graph.pregel[Message](
    Message(0.0, 0.0, 0L),  // Initial message
    3,                       // Max iterations (neighbors of neighbors of neighbors)
    EdgeDirection.Out        // Active direction
  )(
    vProg _,    // Convert methods to function values
    sendMsg _,
    mergeMsg _
  )

  // Carry the updated state into the next window.
  // NOTE(review): each window caches a new RDD and earlier ones are never unpersisted.
  currentVertices = updatedGraph.vertices
  .mapValues { (v: (Double, Double)) => 
    v  // Identity function with explicit type
  }.cache()

  // Save results to Parquet.
  // `id` is the numeric vertex id, not the airport code.
  // NOTE: mode("append") means re-running this cell adds the same windows again.
  currentVertices
    .map { case (id, (delay, prop)) =>
      VertexResult(id, delay, prop, new Timestamp(windowStart))
    }
    .toDF()
    .write
    .mode("append")
    .parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/vertices.parquet")
}

In [0]:
# Read back the per-window vertex states from the path the Scala loop above writes to.
v = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/vertices.parquet")

display(v)

## Without avg prop (delay load only)

In [0]:
%scala


// 1. Load data with coalesced timestamps
// One flight leg: origin/destination airport codes, departure delay and the timestamp used for windowing.
case class Flight(origin: String, dest: String, depDelay: Double, timestamp: Timestamp)

// Use actual_arr_utc when present, otherwise sched_arr_utc; replace null DEP_DELAY with 0.
val flights: RDD[Flight] = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned_pr_v2.parquet/")
  .na.drop(Seq("ORIGIN", "DEST")) // Filter rows with null origins/destinations first
  .withColumn("timestamp", coalesce(col("actual_arr_utc"), col("sched_arr_utc")).cast("timestamp"))
  .filter(col("timestamp").isNotNull) // Remove rows with both timestamps null
  .withColumnRenamed("ORIGIN", "origin")
  .withColumnRenamed("DEST", "dest")
  .na.fill(Map("DEP_DELAY" -> 0))
  .withColumnRenamed("DEP_DELAY", "depDelay")
  .as[Flight]
  .rdd

// 2. Verify flight data (count() runs a full pass over the data)
println(s"Flight count after coalesce: ${flights.count()}")
// flights.limit(5).foreach { f =>
//   println(s"Flight: ${f.origin}->${f.dest} | Timestamp: ${f.timestamp} | Delay: ${f.depDelay}")
// }

// 3. Rebuild airport ID map from FULL dataset: distinct codes are collected to the driver
val airportCodes = flights.flatMap(f => Seq(f.origin, f.dest)).distinct().collect()


// Each code's id is its position in airportCodes
val airportIdMap = airportCodes.zipWithIndex.map { case (code, idx) => 
  (code, idx.toLong) 
}.toMap
println(s"Airport ID map size: ${airportIdMap.size}")

// val bcAirportIds = spark.sparkContext.broadcast(airportIdMap)

// 4. Debug edge creation with coalesced timestamps:
//    key each edge by its 4-hour window start (epoch millis floored to a multiple of 4 hours);
//    airportIdMap is captured by the closure rather than broadcast.
val windowedEdges = flights.map { f =>
  val srcId = airportIdMap(f.origin)
  val dstId = airportIdMap(f.dest)
  val windowStart = f.timestamp.getTime - (f.timestamp.getTime % (4 * 60 * 60 * 1000))
  (windowStart, Edge(srcId, dstId, f.depDelay))
}

// 5. Print edges with type annotations; ids are mapped back to codes by scanning airportIdMap
println("First 10 windowed edges:")
windowedEdges.take(10).foreach { 
  case (window: Long, edge: Edge[Double]) =>
    val windowTime = new Timestamp(window).toString
    val origin = airportIdMap.find(_._2 == edge.srcId).map(_._1).getOrElse("UNKNOWN")
    val dest = airportIdMap.find(_._2 == edge.dstId).map(_._1).getOrElse("UNKNOWN")
    println(s"$windowTime: ${origin}(${edge.srcId}) -> ${dest}(${edge.dstId}) | Delay: ${edge.attr} mins")
}



In [0]:
%scala

// ---- Shared values and mutable state for the per-window loops ----

// All edges of a window grouped under that window's start.
val edgesByWindow = windowedEdges.groupByKey()

// Every airport vertex starts with a delay load of 0.0.
val vertices = for (code <- airportCodes) yield airportIdMap(code) -> 0.0

val verticesRDD = spark.sparkContext.parallelize(vertices)

// Mutable on purpose: the window loops reassign it after each Pregel run.
var currentVertices: RDD[(VertexId, Double)] = verticesRDD

// Window starts in ascending order.
val windowOrder = edgesByWindow.keys.collect().sorted



In [0]:
%scala

// ~~~ Result row and message types (delay-only variant) ~~~

/** One output row: a vertex's delay load recorded for the window starting at `window_start`. */
case class VertexResult(id: Long, delay_load: Double, window_start: Timestamp)

/** Merged inbound message: the sum of delay contributions and how many were summed. */
case class Message(sumDelay: Double, count: Long)

// ~~~ Pregel callbacks ~~~

/** Vertex program.
  *
  * An empty message (count 0) leaves the state untouched. Otherwise the new state is
  * 10% of the old state plus 80% of the mean delay carried by the message.
  */
def vProg(vertexId: VertexId, state: Double, msg: Message): Double = msg match {
  case Message(_, 0L)    => state
  case Message(total, n) => state * 0.1 + (total / n) * 0.8
}

/** Sends half of the edge's delay attribute to the destination vertex, counted as one message. */
def sendMsg(triplet: EdgeTriplet[Double, Double]): Iterator[(VertexId, Message)] =
  Iterator.single(triplet.dstId -> Message(triplet.attr * 0.5, 1L))

/** Combines two messages by adding their sums and their counts. */
def mergeMsg(a: Message, b: Message): Message =
  a.copy(sumDelay = a.sumDelay + b.sumDelay, count = a.count + b.count)

In [0]:
%scala
// Sequential pass over the windows. Each window's graph is seeded with the vertex state left
// by the previous window, so delay load carries forward from window to window.
for (windowStart <- windowOrder) {

  println(s"Window starting ${windowStart}")

  // Fetch this window's edges, then turn them back into an edge RDD for GraphX.
  val windowEdges = edgesByWindow.lookup(windowStart).flatMap(_.iterator)
  val edgeRDD = spark.sparkContext.parallelize(windowEdges)

  val propagated = Graph(currentVertices, edgeRDD).pregel[Message](
    initialMsg = Message(0.0, 0L),
    maxIterations = 3,
    activeDirection = EdgeDirection.Out
  )(vProg _, sendMsg _, mergeMsg _)

  currentVertices = propagated.vertices.cache()

  // Append one row per vertex for this window.
  val rows = currentVertices.map { case (vertexId, load) =>
    VertexResult(vertexId, load, new Timestamp(windowStart))
  }
  rows
    .toDF()
    .write
    .mode("append")
    .parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/vertices.parquet")
}

In [0]:
%scala


// ~~~~~~~~~~~~ Per-window loop without carry-over ~~~~~~~~~~~~
// Every window is seeded from `verticesRDD`, so nothing is carried from one window to the next.
// The `currentVertices` below is a local val that shadows the notebook-level var.
for (windowStart <- windowOrder) {
  println(s"Window starting ${windowStart}")

  val windowEdges = edgesByWindow.lookup(windowStart).flatMap(_.iterator)
  val edgeRDD = spark.sparkContext.parallelize(windowEdges)

  val propagated = Graph(verticesRDD, edgeRDD).pregel[Message](
    Message(0.0, 0L),
    3,
    EdgeDirection.Out
  )(vProg _, sendMsg _, mergeMsg _)

  val currentVertices = propagated.vertices

  // One row per vertex: (id, delay_load, window_start).
  val rows = currentVertices.map { case (vertexId, load) =>
    VertexResult(vertexId, load, new Timestamp(windowStart))
  }
  rows
    .toDF()
    .write
    .mode("append")
    .parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/vertices_delay_only.parquet")
}


## Optimizing the window loop for efficiency

In [0]:
%scala
import org.apache.spark.HashPartitioner


In [0]:
%scala
/** Merged inbound message for the variant that also tracks `prop_delayed`: sums of the prop and
  * delay contributions plus their count, so the vertex program can turn them into averages. */
case class Message(sumProp: Double, sumDelay: Double, count: Long)

/** One output row: a vertex's delay load and prop value for the window starting at `window_start`. */
case class VertexResult(
    id: Long,
    delay_load: Double,
    prop_delayed: Double,
    window_start: Timestamp
)


/** Vertex program.
  *
  * An empty message (count 0) leaves the state untouched. Otherwise the delay component becomes
  * 10% of the old delay plus 80% of the mean incoming delay, and the prop component is replaced
  * by the mean incoming prop.
  *
  * No per-vertex println here: this runs for every vertex in every superstep.
  */
def vProg(vertexId: VertexId, state: (Double, Double), msg: Message): (Double, Double) = {
    if (msg.count == 0) state else {
        val avgProp = msg.sumProp / msg.count
        val avgDelay = msg.sumDelay / msg.count
        val newDelay = state._1 * 0.1 + avgDelay * 0.8
        (newDelay, avgProp)
    }
}

/** Sends the source vertex's prop value and half of the edge delay to the destination vertex.
  *
  * Message has three fields, so the source prop must be passed as the first one; building it
  * from (scaledDelay, 1L) alone does not type-check.
  */
def sendMsg(triplet: EdgeTriplet[(Double, Double), Double]): Iterator[(VertexId, Message)] = {
    val srcProp = triplet.srcAttr._2
    val scaledDelay = triplet.attr * 0.5
    Iterator((triplet.dstId, Message(srcProp, scaledDelay, 1L)))
}

/** Combines two messages by adding each sum and the counts. */
def mergeMsg(a: Message, b: Message): Message = {
    Message(
        a.sumProp + b.sumProp,
        a.sumDelay + b.sumDelay,
        a.count + b.count
    )
}

### Optimized version without `prop_delayed`

In [0]:
%scala

// ~~~~~~~~~~~~~~~~~~~ Types (delay-only variant) ~~~~~~~~~~~~~~~~~~~

/** Output row written per vertex and window. */
case class VertexResult(id: Long, delay_load: Double, window_start: Timestamp)

/** Running sum of delay contributions together with the number of contributions. */
case class Message(sumDelay: Double, count: Long)

// ~~~~~~~~~~~~~~~~~~~ Pregel callbacks ~~~~~~~~~~~~~~~~~~~

/** Updates a vertex's delay load from the merged message; a message with count 0 changes nothing. */
def vProg(vertexId: VertexId, state: Double, msg: Message): Double = {
  val received = msg.count
  if (received != 0) {
    val meanDelay = msg.sumDelay / received
    state * 0.1 + meanDelay * 0.8
  } else {
    state
  }
}

/** Emits a single message to the edge's destination carrying half of the edge delay. */
def sendMsg(triplet: EdgeTriplet[Double, Double]): Iterator[(VertexId, Message)] = {
  val outgoing = Message(triplet.attr * 0.5, 1L)
  Iterator.single((triplet.dstId, outgoing))
}

/** Adds two messages field by field. */
def mergeMsg(a: Message, b: Message): Message = {
  val totalDelay = a.sumDelay + b.sumDelay
  val totalCount = a.count + b.count
  Message(totalDelay, totalCount)
}

In [0]:
%scala
// Uses `windowOrder` as it was computed by an earlier cell.
val numWindows = windowOrder.length
// At least one partition per window, and never fewer than the cluster's default parallelism.
val numPartitions = numWindows max spark.sparkContext.defaultParallelism



In [0]:
%scala
// ---- 1. Key every flight edge by its 4-hour window and hash-partition by that key ----
val windowedEdgesPartitioned = {
  val keyedByWindow = flights.map { flight =>
    val arrivalMillis = flight.timestamp.getTime
    val windowStart = arrivalMillis - (arrivalMillis % (4 * 60 * 60 * 1000))
    (windowStart, Edge(airportIdMap(flight.origin), airportIdMap(flight.dest), flight.depDelay))
  }
  keyedByWindow
    .partitionBy(new HashPartitioner(numPartitions))
    .persist()
}

// ---- 2. Group the partitioned edges per window ----
val edgesByWindow = windowedEdgesPartitioned.groupByKey()

// Every airport vertex starts with a delay load of 0.0.
val vertices = for (code <- airportCodes) yield airportIdMap(code) -> 0.0

val verticesRDD = spark.sparkContext.parallelize(vertices)

// Reassigned by the processing loops after each window.
var currentVertices: RDD[(VertexId, Double)] = verticesRDD

// Window starts in ascending order.
val windowOrder = edgesByWindow.keys.collect().sorted


In [0]:
%scala

// ~~~~~~~~~~~~~~~~~~~ Window loop with periodic checkpointing ~~~~~~~~~~~~~~~~~~~
// Vertex state carries forward from window to window. A window without edges leaves the
// state unchanged but still gets its rows written.
for ((windowStart, idx) <- windowOrder.zipWithIndex) {
  println(s"Window starting ${windowStart}")

  val windowEdges = edgesByWindow.lookup(windowStart).flatMap(_.iterator)

  if (windowEdges.nonEmpty) {
    val edgesRDD = spark.sparkContext.parallelize[Edge[Double]](windowEdges, numPartitions)

    val propagated = Graph(currentVertices, edgesRDD).pregel[Message](
      initialMsg = Message(0.0, 0L),
      maxIterations = 3,
      activeDirection = EdgeDirection.Out
    )(vProg, sendMsg, mergeMsg)

    currentVertices = propagated.vertices
      .mapValues((load: Double) => load)
      .cache()

    // Checkpoint on every 6th window index; count() forces the checkpoint to be written.
    val isCheckpointWindow = idx % 6 == 0
    if (isCheckpointWindow) {
      currentVertices.checkpoint()
      currentVertices.count()
    }
  }

  // Append one row per vertex for this window.
  val rows = currentVertices.map { case (vertexId, load) =>
    VertexResult(vertexId, load, new Timestamp(windowStart))
  }
  rows
    .toDF()
    .write
    .mode("append")
    .parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/vertices_delay_only_partitioned.parquet")
}

In [0]:
%scala
// Requests Kryo serialization through the session's runtime configuration.
// NOTE(review): spark.serializer is normally fixed when the SparkContext is created, so setting it
// here likely does not change the serializer of the already-running context, and recent Spark
// versions may reject the call outright. Confirm, and prefer the cluster's Spark config.
spark.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")


In [0]:
%scala
// ~~~~~~~~~~~~~~~~~~~ Window loop with periodic checkpointing ~~~~~~~~~~~~~~~~~~~
// Walks the windows by position. The vertex state produced by one window seeds the next,
// and rows are written for every window, including those without edges.
windowOrder.indices.foreach { idx =>
  val windowStart = windowOrder(idx)
  println(s"Window starting ${windowStart}")

  val windowEdges = edgesByWindow.lookup(windowStart).flatMap(_.iterator)

  if (windowEdges.nonEmpty) {
    val edgesRDD = spark.sparkContext.parallelize[Edge[Double]](windowEdges, numPartitions)
    val windowGraph = Graph(currentVertices, edgesRDD)

    val propagated = windowGraph.pregel[Message](
      Message(0.0, 0L),
      3,
      EdgeDirection.Out
    )(vProg _, sendMsg _, mergeMsg _)

    currentVertices = propagated.vertices
      .mapValues((load: Double) => load)
      .cache()

    // Every 6th window index: mark for checkpointing and force it with an action.
    if (idx % 6 == 0) {
      currentVertices.checkpoint()
      currentVertices.count()
    }
  }

  // Append this window's vertex rows.
  val rows = currentVertices.map { case (vertexId, load) =>
    VertexResult(vertexId, load, new Timestamp(windowStart))
  }
  rows
    .toDF()
    .write
    .mode("append")
    .parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/vertices_delay_only_partitioned.parquet")
}

In [0]:
%scala
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{SparkSession, DataFrame}
import org.apache.spark.graphx.PartitionStrategy.EdgePartition2D
import java.sql.Timestamp
import scala.collection.mutable.ArrayBuffer

// 1. Classes
/** One output row: a vertex's delay load recorded for the window starting at `window_start`. */
case class VertexResult(id: Long, delay_load: Double, window_start: Timestamp)
/** Merged inbound message: sum of delay contributions and their count. */
case class Message(sumDelay: Double, count: Long)
/** One flight: endpoints, departure delay and the timestamp used to assign its window. */
case class Flight(origin: String, dest: String, depDelay: Double, timestamp: Timestamp)

// 2. Pregel Functions (top-level)

/** Vertex program: count 0 keeps the state; otherwise 10% old state + 80% mean incoming delay. */
def vProg(vertexId: VertexId, state: Double, msg: Message): Double =
  if (msg.count == 0) state else state * 0.1 + (msg.sumDelay / msg.count) * 0.8

/** Sends half of the edge delay to the destination vertex as a single-count message. */
def sendMsg(triplet: EdgeTriplet[Double, Double]): Iterator[(VertexId, Message)] =
  Iterator((triplet.dstId, Message(triplet.attr * 0.5, 1L)))

/** Adds two messages field by field. */
def mergeMsg(a: Message, b: Message): Message =
  Message(a.sumDelay + b.sumDelay, a.count + b.count)

// 3. Session
// The notebook-provided `spark` session is used as-is. Rebuilding it with
// SparkSession.builder().config(...).getOrCreate() hands back the already-running session, so
// serializer settings passed that way do not take effect; configure Kryo on the cluster instead.

// 4. Data Loading & Preprocessing (unchanged)
// `flights`, `airportIdMap`, `numPartitions`, `currentVertices` and `HashPartitioner` come from
// earlier cells.

// 5. Edge Processing with Optimized Grouping
// Collect each window's edges into one array, then hash-partition by window start.
val windowedEdges = flights
  .map { f =>
    val windowStart = f.timestamp.getTime - (f.timestamp.getTime % (4 * 60 * 60 * 1000))
    (windowStart, Edge(airportIdMap(f.origin), airportIdMap(f.dest), f.depDelay))
  }
  .aggregateByKey(ArrayBuffer.empty[Edge[Double]])(
    (buf, edge) => buf += edge,
    (buf1, buf2) => buf1 ++= buf2
  )
  .mapValues(_.toArray)
  .partitionBy(new HashPartitioner(numPartitions))
  .persist()

val edgesByWindow = windowedEdges.groupByKey()
val windowOrder = edgesByWindow.keys.collect().sorted

// 6. Batched Processing
// Windows are still processed one at a time, and only the Parquet write is batched. A snapshot
// RDD is taken right after each window so that every window_start row holds the state as of
// that window. Writing the end-of-batch state for all six windows would stamp earlier windows
// with values that already include later flights.
windowOrder.grouped(6).foreach { windowBatch =>
  val snapshots = windowBatch.map { windowStart =>
    println(s"Window starting ${windowStart}")

    val edges = edgesByWindow.lookup(windowStart).flatMap(identity)
    if (edges.nonEmpty) {
      val edgesRDD = spark.sparkContext.parallelize(edges, numPartitions).flatMap(_.toSeq)
      val graph = Graph(currentVertices, edgesRDD).partitionBy(EdgePartition2D)
      val updatedGraph = graph.pregel[Message](
        Message(0.0, 0L), 2, EdgeDirection.Out
      )(vProg, sendMsg, mergeMsg)
      currentVertices = updatedGraph.vertices.cache()
    }

    // Bound to the vertex RDD of this window, not to whatever currentVertices becomes later.
    currentVertices.map { case (id, delay) =>
      VertexResult(id, delay, new Timestamp(windowStart))
    }
  }

  // Direct Parquet Write: one append per batch covering all of its per-window snapshots.
  spark.sparkContext.union(snapshots.toSeq)
    .toDF()
    .write
    .partitionBy("window_start")
    .mode("append")
    .parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/vertices_delay_only_partitioned.parquet")
}

In [0]:
%scala
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{SparkSession, DataFrame}
import java.sql.Timestamp
import scala.collection.mutable.ArrayBuffer

// 1. Classes
/** One output row: a vertex's delay load recorded for the window starting at `window_start`. */
case class VertexResult(id: Long, delay_load: Double, window_start: Timestamp)
/** Merged inbound message: sum of delay contributions and their count. */
case class Message(sumDelay: Double, count: Long)
/** One flight: endpoints, departure delay and the timestamp used to assign its window. */
case class Flight(origin: String, dest: String, depDelay: Double, timestamp: Timestamp)

// 2. Top-level Pregel functions (no object/closure)

/** Vertex program: count 0 keeps the state; otherwise 10% old state + 80% mean incoming delay. */
def vProg(vertexId: VertexId, state: Double, msg: Message): Double =
  if (msg.count == 0) state else state * 0.1 + (msg.sumDelay / msg.count) * 0.8

/** Sends half of the edge delay to the destination vertex as a single-count message. */
def sendMsg(triplet: EdgeTriplet[Double, Double]): Iterator[(VertexId, Message)] =
  Iterator((triplet.dstId, Message(triplet.attr * 0.5, 1L)))

/** Adds two messages field by field. */
def mergeMsg(a: Message, b: Message): Message =
  Message(a.sumDelay + b.sumDelay, a.count + b.count)


// 4. Data loading
// Timestamp = actual arrival when present, else scheduled arrival; rows with neither are dropped.
val flights: RDD[Flight] = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned_pr_v2.parquet/")
  .na.drop(Seq("ORIGIN", "DEST"))
  .withColumn("timestamp", coalesce(col("actual_arr_utc"), col("sched_arr_utc")).cast("timestamp"))
  .filter(col("timestamp").isNotNull)
  .withColumnRenamed("ORIGIN", "origin")
  .withColumnRenamed("DEST", "dest")
  .na.fill(Map("DEP_DELAY" -> 0))
  .withColumnRenamed("DEP_DELAY", "depDelay")
  .as[Flight]
  .rdd

// 5. Airport ID mapping
val airportCodes = flights.flatMap(f => Seq(f.origin, f.dest)).distinct().collect()
val airportIdMap = airportCodes.zipWithIndex.map { case (code, idx) => (code, idx.toLong) }.toMap

// 6. Edge processing
// Every flight stays its own edge. The RDD is keyed by window start only, so a
// reduceByKey((a, b) => a) here would keep a single arbitrary flight per window instead of
// removing duplicates.
val numPartitions = 200
val windowedEdges = flights
  .map { f =>
    val windowStart = f.timestamp.getTime - (f.timestamp.getTime % (4 * 60 * 60 * 1000))
    (windowStart, Edge(airportIdMap(f.origin), airportIdMap(f.dest), f.depDelay))
  }
  .partitionBy(new org.apache.spark.HashPartitioner(numPartitions))
  .persist()

// Window starts in ascending order, derived from this cell's own edges rather than from
// whatever an earlier cell left in scope.
val windowOrder = windowedEdges.keys.distinct().collect().sorted

// 7. Initialize vertices
val vertices = airportCodes.map(code => (airportIdMap(code), 0.0))
var currentVertices: RDD[(VertexId, Double)] = spark.sparkContext.parallelize(vertices)

// 8. Process windows in batches
// Only the write is batched. A snapshot RDD is taken right after each window so that each
// window_start row holds the state as of that window, not the state at the end of the batch.
windowOrder.grouped(6).foreach { windowBatch =>
  val snapshots = windowBatch.map { windowStart =>
    val edges = windowedEdges.lookup(windowStart)
    if (edges.nonEmpty) {
      val edgesRDD = spark.sparkContext.parallelize(edges, numPartitions)
      val graph = Graph(currentVertices, edgesRDD)
      val updatedGraph = graph.pregel[Message](
        Message(0.0, 0L),
        2,
        EdgeDirection.Out
      )(vProg, sendMsg, mergeMsg)
      currentVertices = updatedGraph.vertices.cache()
    }

    currentVertices.map { case (id, delay) => VertexResult(id, delay, new Timestamp(windowStart)) }
  }

  // Write results: one append per batch covering all of its per-window snapshots.
  spark.sparkContext.union(snapshots.toSeq)
    .toDF()
    .write
    .mode("append")
    .partitionBy("window_start")
    .parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/vertices_delay_only_partitioned.parquet")
}


In [0]:
%scala
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{SparkSession, DataFrame}
import org.apache.spark.graphx.PartitionStrategy.EdgePartition2D
import java.sql.Timestamp
import scala.collection.mutable.ArrayBuffer

// 1. Classes
/** One output row: a vertex's delay load recorded for the window starting at `window_start`. */
case class VertexResult(id: Long, delay_load: Double, window_start: Timestamp)
/** Merged inbound message: sum of delay contributions and their count. */
case class Message(sumDelay: Double, count: Long)
/** One flight: endpoints, departure delay and the timestamp used to assign its window. */
case class Flight(origin: String, dest: String, depDelay: Double, timestamp: Timestamp)

// 2. Pregel Functions (top-level)

/** Vertex program: count 0 keeps the state; otherwise 10% old state + 80% mean incoming delay. */
def vProg(vertexId: VertexId, state: Double, msg: Message): Double =
  if (msg.count == 0) state else state * 0.1 + (msg.sumDelay / msg.count) * 0.8

/** Sends half of the edge delay to the destination vertex as a single-count message. */
def sendMsg(triplet: EdgeTriplet[Double, Double]): Iterator[(VertexId, Message)] =
  Iterator((triplet.dstId, Message(triplet.attr * 0.5, 1L)))

/** Adds two messages field by field. */
def mergeMsg(a: Message, b: Message): Message =
  Message(a.sumDelay + b.sumDelay, a.count + b.count)

// 3. Session
// The notebook-provided `spark` session is used as-is. Rebuilding it with
// SparkSession.builder().config(...).getOrCreate() hands back the already-running session, so
// serializer settings passed that way do not take effect; configure Kryo on the cluster instead.

// 4. Data Loading & Preprocessing (unchanged)
// `flights`, `airportIdMap`, `numPartitions`, `currentVertices` and `HashPartitioner` come from
// earlier cells.

// 5. Edge Processing with Optimized Grouping
// Collect each window's edges into one array, then hash-partition by window start.
val windowedEdges = flights
  .map { f =>
    val windowStart = f.timestamp.getTime - (f.timestamp.getTime % (4 * 60 * 60 * 1000))
    (windowStart, Edge(airportIdMap(f.origin), airportIdMap(f.dest), f.depDelay))
  }
  .aggregateByKey(ArrayBuffer.empty[Edge[Double]])(
    (buf, edge) => buf += edge,
    (buf1, buf2) => buf1 ++= buf2
  )
  .mapValues(_.toArray)
  .partitionBy(new HashPartitioner(numPartitions))
  .persist()

val edgesByWindow = windowedEdges.groupByKey()
val windowOrder = edgesByWindow.keys.collect().sorted

// 6. Batched Processing
// Only the Parquet write is batched. A snapshot RDD is taken right after each window so that
// every window_start row holds the state as of that window. Writing the end-of-batch state for
// all six windows would stamp earlier windows with values that already include later flights.
windowOrder.grouped(6).foreach { windowBatch =>
  val snapshots = windowBatch.map { windowStart =>
    println(s"Window starting ${windowStart}")

    val edges = edgesByWindow.lookup(windowStart).flatMap(identity)
    if (edges.nonEmpty) {
      val edgesRDD = spark.sparkContext.parallelize(edges, numPartitions).flatMap(_.toSeq)
      val graph = Graph(currentVertices, edgesRDD).partitionBy(EdgePartition2D)
      val updatedGraph = graph.pregel[Message](
        Message(0.0, 0L), 2, EdgeDirection.Out
      )(vProg, sendMsg, mergeMsg)
      currentVertices = updatedGraph.vertices.cache()
    }

    // Bound to the vertex RDD of this window, not to whatever currentVertices becomes later.
    currentVertices.map { case (id, delay) =>
      VertexResult(id, delay, new Timestamp(windowStart))
    }
  }

  // Direct Parquet Write: one append per batch covering all of its per-window snapshots.
  spark.sparkContext.union(snapshots.toSeq)
    .toDF()
    .write
    .partitionBy("window_start")
    .mode("append")
    .parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/vertices_delay_only_partitioned.parquet")
}

In [0]:
%scala
// Collect each window's edges into one array, then hash-partition by window start.
// `flights`, `airportIdMap`, `numPartitions`, `currentVertices`, the Pregel callbacks and the
// ArrayBuffer / HashPartitioner / EdgePartition2D imports come from earlier cells.
val windowedEdges = flights
  .map { f =>
    val windowStart = f.timestamp.getTime - (f.timestamp.getTime % (4 * 60 * 60 * 1000))
    (windowStart, Edge(airportIdMap(f.origin), airportIdMap(f.dest), f.depDelay))
  }
  .aggregateByKey(ArrayBuffer.empty[Edge[Double]])(
    (buf, edge) => buf += edge,
    (buf1, buf2) => buf1 ++= buf2
  )
  .mapValues(_.toArray)
  .partitionBy(new HashPartitioner(numPartitions))
  .persist()

val edgesByWindow = windowedEdges.groupByKey()
val windowOrder = edgesByWindow.keys.collect().sorted

// Windows per write batch. Defined here so the cell does not depend on a later cell's value.
val batchSize = 6

// 6. Batched Processing
// Only the Parquet write is batched. A snapshot RDD is taken right after each window so that
// every window_start row holds the state as of that window rather than the end of the batch.
windowOrder.grouped(batchSize).foreach { windowBatch =>
  val snapshots = windowBatch.map { windowStart =>
    // lookup yields the grouped edge arrays; flatten them to plain edges, because Graph(...)
    // needs an RDD[Edge[Double]] and not an RDD of arrays.
    val edges = edgesByWindow.lookup(windowStart).flatMap(identity).flatMap(_.toSeq)
    if (edges.nonEmpty) {
      val edgesRDD = spark.sparkContext.parallelize(edges, numPartitions)
      val graph = Graph(currentVertices, edgesRDD).partitionBy(EdgePartition2D)
      val updatedGraph = graph.pregel[Message](
        Message(0.0, 0L), 2, EdgeDirection.Out
      )(vProg, sendMsg, mergeMsg)
      currentVertices = updatedGraph.vertices.cache()
    }

    currentVertices.map { case (id, delay) =>
      VertexResult(id, delay, new Timestamp(windowStart))
    }
  }

  // Direct Parquet Write: one append per batch covering all of its per-window snapshots.
  spark.sparkContext.union(snapshots.toSeq)
    .toDF()
    .write
    .partitionBy("window_start")
    .mode("append")
    .parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/vertices_delay_only_partitioned.parquet")
}

In [0]:
%scala
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{SparkSession, DataFrame}
import java.sql.Timestamp
import scala.collection.mutable.ArrayBuffer

// 1. Data Classes
/** One output row: a vertex's delay load recorded for the window starting at `window_start`. */
case class VertexResult(id: Long, delay_load: Double, window_start: Timestamp)
/** Merged inbound message: sum of delay contributions and their count. */
case class Message(sumDelay: Double, count: Long)
/** One flight: endpoints, departure delay and the timestamp used to assign its window. */
case class Flight(origin: String, dest: String, depDelay: Double, timestamp: Timestamp)

// 3. Data Loading & Preprocessing
// Timestamp = actual arrival when present, else scheduled arrival; rows with neither are dropped.
val flights: RDD[Flight] = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned_pr_v2.parquet/")
  .na.drop(Seq("ORIGIN", "DEST"))
  .withColumn("timestamp", coalesce(col("actual_arr_utc"), col("sched_arr_utc")).cast("timestamp"))
  .filter(col("timestamp").isNotNull)
  .withColumnRenamed("ORIGIN", "origin")
  .withColumnRenamed("DEST", "dest")
  .na.fill(Map("DEP_DELAY" -> 0.0))
  .withColumnRenamed("DEP_DELAY", "depDelay")
  .as[Flight]
  .rdd

// 4. Airport ID Mapping
val airportCodes = flights.flatMap(f => Seq(f.origin, f.dest)).distinct().collect()
val airportIdMap = airportCodes.zipWithIndex.map { case (code, idx) => (code, idx.toLong) }.toMap

// 5. Edge Processing with Optimized Grouping
// `numPartitions` and `HashPartitioner` come from earlier cells.
println(s"Partitioning")

val windowedEdges = flights
  .map { f =>
    val windowStart = f.timestamp.getTime - (f.timestamp.getTime % (4 * 60 * 60 * 1000))
    (windowStart, Edge(airportIdMap(f.origin), airportIdMap(f.dest), f.depDelay))
  }
  .aggregateByKey(ArrayBuffer.empty[Edge[Double]])(
    (buf, edge) => buf += edge,
    (buf1, buf2) => buf1 ++= buf2
  )
  .mapValues(_.toArray)
  .partitionBy(new HashPartitioner(numPartitions))
  .persist()

// 6. Vertex Initialization
val vertices = airportCodes.map(code => (airportIdMap(code), 0.0))
var currentVertices: RDD[(VertexId, Double)] = spark.sparkContext.parallelize(vertices)

// 7. Pregel Functions

/** Vertex program: count 0 keeps the state; otherwise 10% old state + 80% mean incoming delay. */
def vProg(vertexId: VertexId, state: Double, msg: Message): Double =
  if (msg.count == 0) state else state * 0.1 + (msg.sumDelay / msg.count) * 0.8

/** Sends half of the edge delay to the destination vertex as a single-count message. */
def sendMsg(triplet: EdgeTriplet[Double, Double]): Iterator[(VertexId, Message)] =
  Iterator((triplet.dstId, Message(triplet.attr * 0.5, 1L)))

/** Adds two messages field by field. */
def mergeMsg(a: Message, b: Message): Message =
  Message(a.sumDelay + b.sumDelay, a.count + b.count)

// 8. Window Processing with Batched Optimization
val windowOrder = windowedEdges.keys.collect().sorted
val batchSize = 6  // Process 6 windows at a time
val totalBatches = math.ceil(windowOrder.length.toDouble / batchSize).toInt

windowOrder.grouped(batchSize).zipWithIndex.foreach { case (windowBatch, batchIdx) =>
  // 8.1 Process Batch
  // A snapshot RDD is taken right after each window so that every window_start row holds the
  // state as of that window. Writing the end-of-batch state for the whole batch would stamp
  // earlier windows with values that already include later flights.
  val snapshots = windowBatch.map { windowStart =>
    println(s"Window starting ${windowStart}")

    windowedEdges.lookup(windowStart).foreach { edges =>
      if (edges.nonEmpty) {
        val graph = Graph(
          currentVertices,
          spark.sparkContext.parallelize(edges, numPartitions)
        )
        currentVertices = graph.pregel[Message](
          initialMsg = Message(0.0, 0L),
          maxIterations = 2,
          activeDirection = EdgeDirection.Out
        )(vProg, sendMsg, mergeMsg)
          .vertices
          .cache()
      }
    }

    currentVertices.map { case (id, delay) =>
      VertexResult(id, delay, new Timestamp(windowStart))
    }
  }

  // 8.2 Batched Checkpointing (roughly five times over the whole run)
  if (batchIdx % math.max(totalBatches / 5, 1) == 0) {
    currentVertices.checkpoint()
    currentVertices.count()  // Force materialization
  }

  // 8.3 Batched Parquet Write: one append per batch covering all per-window snapshots
  spark.sparkContext.union(snapshots.toSeq)
    .toDF()
    .repartition(numPartitions, col("window_start"))  // Control file count
    .write
    .partitionBy("window_start")
    .mode("append")
    .parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/vertices_partitioned.parquet")
}

// 9. Cleanup
windowedEdges.unpersist()
spark.sparkContext.clearJobGroup()

In [0]:
%scala
// Pregel callbacks, defined at the top level of the cell (not inside any object or class).

/** Vertex program: a message with count 0 keeps the state; otherwise the new state is
  * 10% of the old state plus 80% of the mean delay in the message. */
def vProg(vertexId: VertexId, state: Double, msg: Message): Double = msg match {
  case Message(_, 0L)    => state
  case Message(total, n) => state * 0.1 + (total / n) * 0.8
}

/** Emits one message to the edge's destination carrying half of the edge delay. */
def sendMsg(triplet: EdgeTriplet[Double, Double]): Iterator[(VertexId, Message)] =
  Iterator.single(triplet.dstId -> Message(triplet.attr * 0.5, 1L))

/** Adds two messages field by field. */
def mergeMsg(a: Message, b: Message): Message =
  a.copy(sumDelay = a.sumDelay + b.sumDelay, count = a.count + b.count)

With `prop_delayed` (proportion of delayed flights):

In [0]:
%scala
// Edge grouping
// Group the already hash-partitioned edges per window. The values must be a real collection:
// an Iterator is single-use and not serializable, so it cannot be returned by lookup().
val edgesByWindow = windowedEdgesPartitioned.groupByKey()



// ---- 7. Batched Window Processing ----
val batchSize = 6
val checkpointInterval = math.max(windowOrder.length / 10, 1)

windowOrder.grouped(batchSize).zipWithIndex.foreach { case (windowBatch, batchIdx) =>
  // Rows are collected right after each window so that every window_start row holds the state
  // as of that window, not the state left at the end of the batch.
  val batchResults = windowBatch.flatMap { windowStart =>

    println(s"Window starting ${windowStart}")
    val edges = edgesByWindow.lookup(windowStart).flatMap(identity)
    if (edges.nonEmpty) {
      val edgesRDD = spark.sparkContext.parallelize(edges, numPartitions)
      val graph = Graph(currentVertices, edgesRDD)
      val updatedGraph = graph.pregel[Message](
        Message(0.0, 0L),
        2,  // Fewer iterations for speed
        EdgeDirection.Out
      )(vProg, sendMsg, mergeMsg)
      currentVertices = updatedGraph.vertices.cache()
    }

    currentVertices
      .map { case (id, delay) => VertexResult(id, delay, new Timestamp(windowStart)) }
      .collect()
  }

  // Checkpoint once per selected batch (not once per window of that batch).
  if (batchIdx % checkpointInterval == 0) {
    currentVertices.checkpoint()
    currentVertices.count()
  }

  // Write batch results
  spark.createDataFrame(batchResults)
    .write
    .partitionBy("window_start")
    .mode("append")
    .parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/vertices_delay_only_partitioned.parquet")
}


In [0]:
%scala
// 1. Define Pregel functions in a serializable object
// This cell is the with-prop variant: it expects the three-field Message, the four-field
// VertexResult and a `currentVertices` of type RDD[(VertexId, (Double, Double))] to be in scope.
object PregelFunctions extends Serializable {
  /** Count 0 keeps the state; otherwise delay = 10% old delay + 80% mean incoming delay and
    * prop = mean incoming prop. */
  val vProg = (vertexId: VertexId, state: (Double, Double), msg: Message) => {
    if (msg.count == 0) state else {
      val avgProp = msg.sumProp / msg.count
      val avgDelay = msg.sumDelay / msg.count
      val newDelay = state._1 * 0.1 + avgDelay * 0.8
      (newDelay, avgProp)
    }
  }

  /** Sends a delayed flag (1.0 when the edge delay exceeds 15, else 0.0) and half of the edge
    * delay to the destination vertex. */
  val sendMsg = (triplet: EdgeTriplet[(Double, Double), Double]) => {
    val isDelayed = if (triplet.attr > 15) 1.0 else 0.0
    Iterator((triplet.dstId, Message(isDelayed, triplet.attr * 0.5, 1L)))
  }

  /** Adds two messages field by field. */
  val mergeMsg = (a: Message, b: Message) => {
    Message(a.sumProp + b.sumProp, a.sumDelay + b.sumDelay, a.count + b.count)
  }
}


spark.conf.set("spark.sql.parquet.compression.codec", "snappy")

// 3. Calculate dynamic partitioning
// Guard both divisions: an empty windowOrder must not divide by zero, and the per-partition
// division has to happen in floating point. With integer division any average below 100000
// rounds to 0 partitions, which parallelize rejects.
val avgEdgesPerWindow = 42373908 / math.max(windowOrder.length, 1)
val numPartitions = math.max(1, math.ceil(avgEdgesPerWindow.toDouble / 100000).toInt)  // ~100k edges/partition

// 4. Process windows in batches (6 windows per batch)
val batchSize = 6
// Checkpoint about every 10% of the batches; at least 1 so the modulo below never divides by zero.
val checkpointEvery = math.max(1, windowOrder.length / batchSize / 10)

windowOrder.grouped(batchSize).zipWithIndex.foreach { case (windowBatch, batchIdx) =>
  // Process each window in the batch and collect its rows right away, so every window_start
  // row holds the state as of that window, not the state left at the end of the batch.
  val batchResults = windowBatch.flatMap { windowStart =>
    val edges = edgesByWindow.lookup(windowStart).flatMap(_.iterator)

    if (edges.nonEmpty) {
      val graph = Graph(
        currentVertices,
        spark.sparkContext.parallelize(edges, numPartitions)
      )

      val updatedGraph = graph.pregel[Message](
        initialMsg = Message(0.0, 0.0, 0L),
        maxIterations = 2,  // Reduced from 3
        activeDirection = EdgeDirection.Out
      )(
        PregelFunctions.vProg,
        PregelFunctions.sendMsg,
        PregelFunctions.mergeMsg
      )

      currentVertices = updatedGraph.vertices
        .mapValues(v => v)
        .cache()
    }

    currentVertices.map { case (id, (delay, prop)) =>
      VertexResult(id, delay, prop, new Timestamp(windowStart))
    }.collect()
  }

  // Write batched results
  spark.createDataFrame(batchResults)
    .write
    .partitionBy("window_start")
    .mode("append")
    .parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/vertices_batched")

  // Checkpoint every 10% of batches
  if (batchIdx % checkpointEvery == 0) {
    currentVertices.checkpoint()
    currentVertices.count()  // Force materialization
  }
}



In [0]:
%scala


// With-prop window loop: vertex state is a (delay, prop) pair that carries forward from window
// to window. Rows are written for every window, including those without edges.
for ((windowStart, idx) <- windowOrder.zipWithIndex) {
  val windowEdges = edgesByWindow.lookup(windowStart).flatMap(_.iterator)

  if (windowEdges.nonEmpty) {
    val edgeRDD = spark.sparkContext.parallelize(windowEdges, numPartitions)

    val propagated = Graph(currentVertices, edgeRDD).pregel[Message](
      initialMsg = Message(0.0, 0.0, 0L),
      maxIterations = 3,
      activeDirection = EdgeDirection.Out
    )(vProg _, sendMsg _, mergeMsg _)

    currentVertices = propagated.vertices
      .mapValues((state: (Double, Double)) => state)
      .cache()

    // Every 6th window index: mark for checkpointing and force it with an action.
    val isCheckpointWindow = idx % 6 == 0
    if (isCheckpointWindow) {
      currentVertices.checkpoint()
      currentVertices.count()
    }
  }

  // Append one row per vertex for this window.
  val rows = currentVertices.map { case (vertexId, (load, proportion)) =>
    VertexResult(vertexId, load, proportion, new Timestamp(windowStart))
  }
  rows
    .toDF()
    .write
    .partitionBy("window_start")
    .mode("append")
    .parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/vertices_partitioned")

}


In [0]:
%scala


// With-prop window loop: vertex state is a (delay, prop) pair that carries forward from window
// to window. Rows are written for every window, including those without edges.
windowOrder.zipWithIndex.foreach { case (windowStart, idx) =>
  val edges = edgesByWindow.lookup(windowStart).flatMap(_.iterator)

  if (edges.nonEmpty) {
    val graph = Graph(currentVertices, spark.sparkContext.parallelize(edges, numPartitions))

    val updatedGraph = graph.pregel[Message](
      // Empty initial message. The with-prop Message has three fields
      // (sumProp, sumDelay, count), and count is a Long.
      Message(0.0, 0.0, 0L),
      3,                       // Max iterations (neighbors of neighbors of neighbors)
      EdgeDirection.Out        // Active direction
    )(
      vProg _,    // Convert methods to function values
      sendMsg _,
      mergeMsg _
    )

    currentVertices = updatedGraph.vertices
      .mapValues(v => v)
      .cache()

    // Checkpoint every 6 windows
    if (idx % 6 == 0) {
      currentVertices.checkpoint()
      currentVertices.count()
    }
  }


  // Save results to Parquet
  currentVertices
    .map { case (id, (delay, prop)) =>
      VertexResult(id, delay, prop, new Timestamp(windowStart))
    }
    .toDF()
    .write
    .partitionBy("window_start")
    .mode("append")
    .parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/vertices_partitioned")

}


## end here

In [0]:
%scala

// // Broadcast the airport ID map
// val bcAirportIds = spark.sparkContext.broadcast(airportIdMap)

// Build vertices
// val vertices = airportCodes.map { code =>
//   (airportIdMap(code), (0.0, 0.2))  // (id, (delayLoad, propDelayed))
// }.toSeq
// val verticesRDD = spark.sparkContext.parallelize(vertices)

// // Build edges with mapped IDs
// val windowedEdges = flights.flatMap { f =>
//   val srcId = bcAirportIds.value.getOrElse(f.origin, -1L)
//   val dstId = bcAirportIds.value.getOrElse(f.dest, -1L)
//   if (srcId == -1L || dstId == -1L) {
//     None  // Filtered out by flatMap
//   } else {
//     val windowStart = f.timestamp.getTime - (f.timestamp.getTime % (4 * 60 * 60 * 1000))
//     Some((windowStart, Edge(srcId, dstId, f.depDelay)))
//   }
// }

// val edgesByWindow = windowedEdges.groupByKey()

// Define message type to track sum and count for averages
case class Message(sumProp: Double, sumDelay: Double, count: Long)

// Class for saving results
case class VertexResult(
    id: Long, 
    delay_load: Double, 
    prop_delayed: Double, 
    window_start: Timestamp
)

// Vertex program: folds the merged incoming message into the vertex state.
// The state is a (delay, prop) pair. A message with count == 0 (such as the
// initial message) leaves the state unchanged; otherwise the new delay is
// 0.1 * old delay + 0.8 * mean message delay and the new prop is the mean message prop.
def vProg(vertexId: VertexId, state: (Double, Double), msg: Message): (Double, Double) = {
    // Debug trace; printed wherever the vertex program executes.
    println(s"vProg called for vertex $vertexId with state $state and msg $msg")
    msg.count match {
        case 0L => state
        case received =>
            val meanProp = msg.sumProp / received
            val meanDelay = msg.sumDelay / received
            (state._1 * 0.1 + meanDelay * 0.8, meanProp)
    }
}

// Message sender: for every edge, emits one message to the destination vertex
// carrying the source's prop value and half of the edge attribute, with count 1.
def sendMsg(triplet: EdgeTriplet[(Double, Double), Double]): Iterator[(VertexId, Message)] = {
  // Debug trace; printed once per edge triplet.
  println(s"sendMsg: src=${triplet.srcId}, dst=${triplet.dstId}, attr=${triplet.attr}")
  val (_, senderProp) = triplet.srcAttr
  val outgoing = Message(senderProp, triplet.attr * 0.5, 1L)
  Iterator.single((triplet.dstId, outgoing))
}

// Message combiner: adds the two messages field by field (sums and counts).
def mergeMsg(a: Message, b: Message): Message = {
    val totalProp = a.sumProp + b.sumProp
    val totalDelay = a.sumDelay + b.sumDelay
    val totalCount = a.count + b.count
    Message(totalProp, totalDelay, totalCount)
}

// Vertex state carried from one window to the next, starting from verticesRDD.
var currentVertices: RDD[(VertexId, (Double, Double))] = verticesRDD
// Window keys, sorted ascending, so state is propagated in key order.
val windowOrder = edgesByWindow.keys.collect().sorted

for (windowStart <- windowOrder) {
  // Pull this window's edges to the driver and redistribute them as an RDD.
  val windowEdges = edgesByWindow.lookup(windowStart).flatMap(_.iterator)
  val edgeRDD = spark.sparkContext.parallelize(windowEdges)

  // Graph for this window: previous vertex state + this window's edges.
  val graph = Graph(currentVertices, edgeRDD)

  val initialMessage = Message(0.0, 0.0, 0L) // count 0
  val maxIterations = 3
  val updatedGraph = graph.pregel[Message](initialMessage, maxIterations, EdgeDirection.Out)(
    vProg _,
    sendMsg _,
    mergeMsg _
  )

  // Identity map with an explicit parameter type, then cache for reuse below
  // and in the next window.
  currentVertices = updatedGraph.vertices
    .mapValues { (state: (Double, Double)) => state }
    .cache()

  // Append this window's vertex state to the Parquet output.
  val windowRows = currentVertices.map { case (id, (delay, prop)) =>
    VertexResult(id, delay, prop, new Timestamp(windowStart))
  }
  windowRows
    .toDF()
    .write
    .mode("append")
    .parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/vertices")
}

In [0]:
%scala
import org.apache.spark.sql.functions.coalesce

// 1. Load data with coalesced timestamps
// "timestamp" is actual_arr_utc when present, otherwise sched_arr_utc; rows where
// both are null are removed. Null DEP_DELAY values are replaced by 0.
// NOTE(review): Flight, RDD, col, Edge and Timestamp are not defined/imported in
// this cell; it relies on an earlier cell having done so.
val flights: RDD[Flight] = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned_pr_v2.parquet/")
  .na.drop(Seq("ORIGIN", "DEST")) // Filter rows with null origins/destinations first
  .withColumn("timestamp", coalesce(col("actual_arr_utc"), col("sched_arr_utc")).cast("timestamp"))
  .filter(col("timestamp").isNotNull) // Remove rows with both timestamps null
  .withColumnRenamed("ORIGIN", "origin")
  .withColumnRenamed("DEST", "dest")
  .na.fill(Map("DEP_DELAY" -> 0))
  .withColumnRenamed("DEP_DELAY", "depDelay")
  .as[Flight]
  .rdd

// 2. Verify flight data
// count() is a Spark action: it runs the whole load pipeline above.
println(s"Flight count after coalesce: ${flights.count()}")
// flights.limit(5).foreach { f =>
//   println(s"Flight: ${f.origin}->${f.dest} | Timestamp: ${f.timestamp} | Delay: ${f.depDelay}")
// }

// 3. Rebuild airport ID map from FULL dataset
// Every distinct origin/destination code is collected to the driver and numbered
// by its position in the collected array.
val airportCodes = flights.flatMap(f => Seq(f.origin, f.dest)).distinct().collect()
val airportIdMap = airportCodes.zipWithIndex.map { case (code, idx) => 
  (code, idx.toLong) 
}.toMap
println(s"Airport ID map size: ${airportIdMap.size}")

// 4. Debug edge creation with coalesced timestamps
// Each flight becomes (windowStart, Edge(srcId, dstId, depDelay)), where windowStart
// is the flight's epoch-millisecond timestamp rounded down to a multiple of 4 hours.
// airportIdMap is captured by the closure; both codes are always present because the
// map was built from this same RDD.
val windowedEdges = flights.map { f =>
  val srcId = airportIdMap(f.origin)
  val dstId = airportIdMap(f.dest)
  val windowStart = f.timestamp.getTime - (f.timestamp.getTime % (4 * 60 * 60 * 1000))
  (windowStart, Edge(srcId, dstId, f.depDelay))
}

// 5. Print edges with type annotations
// Shows the first 10 edges with airport codes recovered by a reverse lookup in airportIdMap.
println("First 10 windowed edges:")
windowedEdges.take(10).foreach { 
  case (window: Long, edge: Edge[Double]) =>
    val windowTime = new Timestamp(window).toString
    val origin = airportIdMap.find(_._2 == edge.srcId).map(_._1).getOrElse("UNKNOWN")
    val dest = airportIdMap.find(_._2 == edge.dstId).map(_._1).getOrElse("UNKNOWN")
    println(s"$windowTime: ${origin}(${edge.srcId}) -> ${dest}(${edge.dstId}) | Delay: ${edge.attr} mins")
}



In [0]:


# Map every airport code (seen as ORIGIN or DEST) to a unique Long ID.
vertex_ids = (
    train0.select("ORIGIN")
    .union(train0.select("DEST"))
    .distinct()
    .rdd.map(lambda r: r[0])
    .zipWithUniqueId()
    .collectAsMap()
)

# Vertices: (id, (delay_load, prop_delayed)).
# Built from the full ID map rather than from ORIGIN only, so airports that
# appear only as a destination also get a vertex (the edges below reference them).
vertices_rdd = sc.parallelize([
    (vertex_id, (0.0, 0.2)) for vertex_id in vertex_ids.values()
])

# Edges: Edge(src_id, dst_id, (DEP_DELAY, actual_arr_utc)).
# NOTE(review): `Edge` is not defined in this cell; it must come from elsewhere.
edges_rdd = train0.rdd.map(lambda r: Edge(
    vertex_ids[r["ORIGIN"]],
    vertex_ids[r["DEST"]],
    (r["DEP_DELAY"], r["actual_arr_utc"])
))


In [0]:
vertices_rdd.take(5)

In [0]:
edges_rdd.take(5)

In [0]:


def assign_window(ts):
    """Return the (start, end) of the 4-hour window that contains ``ts``.

    Windows are aligned to 00:00, 04:00, 08:00, ... of the timestamp's own
    calendar day.

    Args:
        ts: An ISO-8601 string accepted by ``datetime.fromisoformat`` or a
            ``datetime`` instance.

    Returns:
        Tuple ``(window_start, window_end)`` of ``datetime`` objects, where
        ``window_end`` is ``window_start`` plus four hours.
    """
    if not isinstance(ts, datetime):
        ts = datetime.fromisoformat(ts)
    # Truncate to the 4-hour boundary. Minutes, seconds and microseconds must be
    # zeroed too, otherwise every distinct timestamp yields its own "window" key.
    window_start = ts.replace(
        hour=ts.hour - ts.hour % 4, minute=0, second=0, microsecond=0
    )
    return (window_start, window_start + timedelta(hours=4))

# Key each edge by the window of its timestamp (second element of the edge
# attribute), then gather all edges that share a window.
windowed_edges = (
    edges_rdd
    .keyBy(lambda edge: assign_window(edge.attr[1]))
    .groupByKey()
)


def run_pregel(edges, initial_vertices):
    """Run three Pregel iterations over one window's graph.

    Vertex state is ``(delay_load, prop_delayed)``. Messages are
    ``(sum_prop, sum_delay, count)`` triples so that the vertex program can
    compute true averages after merging; a count of 0 (the initial message)
    leaves the vertex state unchanged.

    Args:
        edges: Edges for the window. The edge attribute may be the delay itself
            or a tuple/list whose first element is the delay.
        initial_vertices: Vertices carrying the state from the previous window.

    Returns:
        Whatever ``graph.pregel`` returns (the updated graph).

    NOTE(review): ``Graph`` and its ``pregel`` method are not defined in this
    cell; they must be provided elsewhere.
    """
    graph = Graph(initial_vertices, edges)

    def vprog(vid, old_state, msg):
        """Fold the merged message into the vertex state."""
        sum_prop, sum_delay, count = msg if msg else (0.0, 0.0, 0)
        if count == 0:
            # Nothing received (or initial message): keep the carried-over state.
            return old_state
        old_delay, _old_prop = old_state
        avg_prop = sum_prop / count
        avg_delay = sum_delay / count
        return (old_delay * 0.1 + avg_delay * 0.8, avg_prop)

    def send_msg(triplet):
        """Send the source's prop and half the edge delay to the destination."""
        _src_delay, src_prop = triplet.srcAttr
        attr = triplet.attr
        edge_delay = attr[0] if isinstance(attr, (tuple, list)) else attr
        if edge_delay is None:
            edge_delay = 0.0  # a missing delay contributes nothing
        return [(triplet.dstId, (src_prop, edge_delay * 0.5, 1))]

    def merge_msg(a, b):
        """Combine two messages by adding sums and counts."""
        prop_a, delay_a, count_a = a
        prop_b, delay_b, count_b = b
        return (prop_a + prop_b, delay_a + delay_b, count_a + count_b)

    return graph.pregel(
        initial_msg=(0.0, 0.0, 0),
        max_iterations=3,
        vprog=vprog,
        sendMsg=send_msg,
        mergeMsg=merge_msg
    )



In [0]:
# Initialize vertex state
current_vertices = vertices_rdd

# collect() returns (window, iterable-of-edges) pairs in no guaranteed order.
# State is carried from window to window, so process them chronologically.
for window, edges in sorted(windowed_edges.collect(), key=lambda pair: pair[0]):
    print(f"Processing {window}")

    # The grouped edges are a local iterable, not an RDD: keep only the delay
    # as the edge attribute and redistribute them.
    graphx_edges = sc.parallelize(
        [Edge(e.srcId, e.dstId, e.attr[0]) for e in edges]
    )

    # Run Pregel
    updated_graph = run_pregel(graphx_edges, current_vertices)

    # Update vertices for next window
    current_vertices = updated_graph.vertices

    # Store results; append so each window adds to the same dataset instead of
    # failing because the path already exists.
    current_vertices.map(lambda x: (x[0], x[1][0], x[1][1], window[0])) \
        .toDF(["id", "delay_load", "prop_delayed", "window_start"]) \
        .write.mode("append") \
        .parquet("dbfs:/student-groups/Group_4_1/interim/modeling_checkpoints/pregel_rdd.parquet")

# WORKING VERS RUN FULL

In [0]:
# Edge-shaped projection of train0: (src, dst, actual_arr_utc).
train0_msg = train0.select(
    *[
        F.col(source).alias(target)
        for source, target in (("ORIGIN", "src"), ("DEST", "dst"))
    ],
    F.col("actual_arr_utc"),
)

In [0]:
# One vertex per airport code seen as ORIGIN or DEST, with the initial state.
vertices = (
    train0.select(F.col("ORIGIN").alias("id"))
    .union(train0.select(F.col("DEST").alias("id")))
    .distinct()
    .withColumn("delay_load", F.lit(0.0))
    .withColumn("prop_delayed", F.lit(0.2)) #initialized from EDA
)
# Accumulator for per-window vertex states, seeded with a single placeholder row
# whose window columns hold the string 'PASS'.
# NOTE(review): the placeholder makes the window columns strings while later rows
# add timestamps; confirm the resulting column type is what downstream code expects.
vertices_store = vertices.limit(1).withColumn('window_start', F.lit('PASS')).withColumn('window_end', F.lit('PASS'))
# Define edges with raw delay minutes (no normalization)
edges = train0.select(
    F.col("ORIGIN").alias("src"),
    F.col("DEST").alias("dst"),
    F.col("prop_delayed").alias("edge_prop_delayed"),
    F.col("actual_arr_utc").alias("timestamp"), 
    F.col("DEP_DELAY").alias("delay_load")
)


# Tag every edge with its 4-hour tumbling window
edges_with_windows = edges.withColumn(
    "window", 
    F.window("timestamp", "4 hours")
).orderBy("timestamp")

# Cache and materialize: the frame is filtered once per window below.
edges_with_windows.cache()
edges_with_windows.count()
# Distinct windows, oldest first, collected to the driver.
windows = edges_with_windows.select("window.start", "window.end") \
    .distinct().orderBy('window.start').collect()

for window in windows:
    print(f"Processing window {window}")
    # NOTE(review): train0_msg is rebuilt here but not used again in this cell.
    train0_msg=train0_msg.withColumns({'window_start': F.lit(window.start), 'window_end': F.lit(window.end)})

    # 1. Filter edges in the current window
    current_edges = edges_with_windows.filter(
        (F.col("timestamp") >= F.lit(window.start)) & 
        (F.col("timestamp") < F.lit(window.end))
        )

    # 2. Build graph using previous state
    g = GraphFrame(vertices, current_edges)

    # 3. Run Pregel: each vertex column is given as (name, initial value, update expression)
    result = g.pregel \
        .withVertexColumn(
            "new_delay_state",
            F.col("delay_load"),
            F.col("delay_load")*.1 + 
            F.coalesce(Pregel.msg().getItem("avg_delay")*.8, 
                           F.lit(0.0)) # 0.1 * carried-over delay + 0.8 * average incoming delay (0 if no message)
        ) \
        .withVertexColumn(
            "new_prop_delayed",
            F.col("prop_delayed"),
            F.coalesce(Pregel.msg().getItem("avg_prop"), F.lit(0.0))
        ) \
        .sendMsgToDst(
            F.struct(
                Pregel.src("new_prop_delayed").alias("prop"),  # source's current prop (not read by aggMsgs below)
                (Pregel.edge("delay_load") * .5).alias("delay")  # edge delay scaled by a constant 0.5
            )
        ) \
        .aggMsgs(
            F.struct(
                F.avg(F.when(Pregel.msg().getItem("delay") > 15, 1).otherwise(0)).alias("avg_prop"),  # share of messages whose scaled delay exceeds 15 - NOTE(review): that is a raw delay above 30; confirm intended
                F.avg(Pregel.msg().getItem("delay")).alias("avg_delay")  # Average scaled delays
            )
        ) \
        .setMaxIter(3) \
        .run()

    # 4. Update vertices for next window (no capping is applied here)
    vertices = result.select(
        "id",
        F.col("new_prop_delayed").alias("prop_delayed"),
        F.col("new_delay_state").alias("delay_load")
    ).cache()
    # Append this window's state, labelled with the window bounds, to the accumulator.
    vertices_store = vertices_store.unionByName(vertices.withColumn('window_start', F.lit(window.start)) \
                         .withColumn('window_end', F.lit(window.end)))
    
    # 5. Force materialization to persist state
    vertices.count()
    # NOTE(review): a new DataFrame is cached every window and earlier ones are never unpersisted.
    vertices_store.cache()
    vertices_store.count()

In [0]:
# Number of distinct origin airports per flight date.
display(train0.groupBy('FL_DATE').agg(F.countDistinct('ORIGIN').alias('distinct_origin_count')))

In [0]:
# One vertex per airport code seen as ORIGIN or DEST, with the initial state.
vertices = (
    train0.select(F.col("ORIGIN").alias("id"))
    .union(train0.select(F.col("DEST").alias("id")))
    .distinct()
    .withColumn("delay_load", F.lit(0.0))
    .withColumn("prop_delayed", F.lit(0.2)) #initialized from EDA
)
# Placeholder-seeded accumulator; not extended in this cell, which writes each
# window straight to Parquet instead.
vertices_store = vertices.limit(1).withColumn('window_start', F.lit('PASS')).withColumn('window_end', F.lit('PASS'))
# Define edges with raw delay minutes (no normalization)
edges = train0.select(
    F.col("ORIGIN").alias("src"),
    F.col("DEST").alias("dst"),
    F.col("prop_delayed").alias("edge_prop_delayed"),
    F.col("actual_arr_utc").alias("timestamp"),
    F.col("DEP_DELAY").alias("delay_load")
)


# Tag every edge with its 4-hour tumbling window
edges_with_windows = edges.withColumn(
    "window",
    F.window("timestamp", "4 hours")
).orderBy("timestamp")

# Cache and materialize: the frame is filtered once per window below.
edges_with_windows.cache()
edges_with_windows.count()
windows = edges_with_windows.select("window.start", "window.end") \
    .distinct().orderBy('window.start').collect()

for idx, window in enumerate(windows):
    print(f"Processing window {window}")

    # 1. Filter edges in the current window
    current_edges = edges_with_windows.filter(
        (F.col("timestamp") >= F.lit(window.start)) &
        (F.col("timestamp") < F.lit(window.end))
    )

    # (A dangling `current_vertices = ` assignment used to sit here; it made the
    # whole cell a SyntaxError and the name was never used, so it is removed.)

    # 2. Build graph using previous state
    g = GraphFrame(vertices, current_edges)

    # 3. Run Pregel: each vertex column is (name, initial value, update expression)
    result = g.pregel \
        .withVertexColumn(
            "new_delay_state",
            F.col("delay_load"),
            F.col("delay_load")*.1 +
            F.coalesce(Pregel.msg().getItem("avg_delay")*.8,
                           F.lit(0.0)) # 0.1 * carried-over delay + 0.8 * average incoming delay
        ) \
        .withVertexColumn(
            "new_prop_delayed",
            F.col("prop_delayed"),
            F.coalesce(Pregel.msg().getItem("avg_prop"), F.lit(0.0))
        ) \
        .sendMsgToDst(
            F.struct(
                Pregel.src("new_prop_delayed").alias("prop"),  # source's current prop (not read by aggMsgs below)
                (Pregel.edge("delay_load") * .5).alias("delay")  # edge delay scaled by a constant 0.5
            )
        ) \
        .aggMsgs(
            F.struct(
                F.avg(F.when(Pregel.msg().getItem("delay") > 15, 1).otherwise(0)).alias("avg_prop"),  # share of messages whose scaled delay exceeds 15
                F.avg(Pregel.msg().getItem("delay")).alias("avg_delay")  # Average scaled delays
            )
        ) \
        .setMaxIter(2) \
        .run()

    # 4. Update vertices for next window. Materialize the new state first and
    #    only then drop the previous cache, so the new state is not recomputed
    #    from an input that has just been evicted.
    previous_vertices = vertices
    vertices = result.select(
        "id",
        F.col("new_prop_delayed").alias("prop_delayed"),
        F.col("new_delay_state").alias("delay_load")
    ).cache()
    vertices.count()
    previous_vertices.unpersist()

    # Checkpoint periodically to truncate the growing lineage
    if idx % 5 == 0:
        vertices = vertices.checkpoint(eager=True)

    # 5. Append this window's state, labelled with the window bounds
    vertices.withColumn('window_start', F.lit(window.start)).withColumn('window_end', F.lit(window.end)).write.mode("append").parquet("dbfs:/student-groups/Group_4_1/interim/modeling_checkpoints/pregel.parquet")

In [0]:
from pyspark.sql import functions as F
from graphframes import GraphFrame
from pyspark import StorageLevel
from pyspark.sql.window import Window

# Configure Spark for graph processing: this changes the session-wide number of
# shuffle partitions, so it also affects every cell run afterwards.
spark.conf.set("spark.sql.shuffle.partitions", "2000")

def optimize_graph_processing(df):
    """Propagate delay state over the airport graph, one 4-hour window at a time.

    For every window only the airports touched by that window's flights are
    run through Pregel; their new state is merged back into the global vertex
    table, which is appended to the Parquet output after each window.

    Args:
        df: Flights DataFrame. The columns ORIGIN, DEST, prop_delayed,
            actual_arr_utc and DEP_DELAY are read.

    Returns:
        DataFrame with columns id, prop_delayed and delay_load holding the
        vertex state after the last window.

    Side effects:
        Appends one batch per window to
        dbfs:/student-groups/Group_4_1/interim/modeling_checkpoints/pregel.parquet
        and overwrites dbfs:/checkpoints/vertices_<n> on every 10th window.
    """
    # Local import: this cell imports GraphFrame but not Pregel.
    from graphframes.lib import Pregel

    # Vertex initialization. Uses the `df` argument (previously the global
    # `train0` was read and the parameter was ignored).
    vertices = (
        df.select(F.col("ORIGIN").alias("id"))
        .union(df.select(F.col("DEST").alias("id")))
        .distinct()
        .repartitionByRange(200, "id")
        .withColumn("delay_load", F.lit(0.0))
        .withColumn("prop_delayed", F.lit(0.2))
        .persist(StorageLevel.MEMORY_AND_DISK)
    )

    # Edges tagged with their 4-hour window and repartitioned by window start.
    edges_with_windows = (
        df.select(
            F.col("ORIGIN").alias("src"),
            F.col("DEST").alias("dst"),
            F.col("prop_delayed").alias("edge_prop_delayed"),
            F.col("actual_arr_utc").alias("timestamp"),
            F.col("DEP_DELAY").alias("delay_load")
        )
        .withColumn("window", F.window("timestamp", "4 hours"))
        .withColumn("window_start", F.col("window.start"))  # Explicit column
        .repartition("window_start")
        .persist(StorageLevel.MEMORY_AND_DISK)
    )

    windows = edges_with_windows.select("window.start", "window.end").distinct().orderBy('start').collect()

    # enumerate() supplies the window index directly instead of an O(n)
    # list.index() lookup on every iteration.
    for window_id, window_row in enumerate(windows):
        window = (window_row.start, window_row.end)

        # Label the Spark jobs of this window in the UI
        spark.sparkContext.setJobGroup(
            f"window_{window_id}",
            f"Processing {window[0]} - {window[1]}"
        )

        # Edges whose timestamp falls inside the current window
        current_edges = edges_with_windows.filter(
            (F.col("timestamp") >= window[0]) &
            (F.col("timestamp") < window[1])
        )

        # Restrict the graph to airports that appear in this window. The union
        # keeps the first input's column name ("src"), which is renamed to "id".
        active_vertices = current_edges.select("src").union(current_edges.select("dst")).distinct()
        active_vertices = active_vertices.withColumnRenamed("src", "id").withColumnRenamed("dst", "id")
        window_vertices = vertices.join(active_vertices, "id", "inner").persist()

        g = GraphFrame(window_vertices, current_edges)

        # Pregel: each vertex column is (name, initial value, update expression)
        result = g.pregel \
            .withVertexColumn(
                "new_delay_state",
                F.col("delay_load"),
                F.col("delay_load") * 0.1 +
                F.coalesce(Pregel.msg().getItem("avg_delay") * 0.8, F.lit(0.0))
            ) \
            .withVertexColumn(
                "new_prop_delayed",
                F.col("prop_delayed"),
                F.coalesce(Pregel.msg().getItem("avg_prop"), F.lit(0.0))
            ) \
            .sendMsgToDst(
                F.struct(
                    Pregel.src("new_prop_delayed").alias("prop"),
                    (Pregel.edge("delay_load") * 0.5).alias("delay")
                )
            ) \
            .aggMsgs(
                F.struct(
                    F.avg(F.when(Pregel.msg().getItem("delay") > 15, 1).otherwise(0)).alias("avg_prop"),
                    F.avg(Pregel.msg().getItem("delay")).alias("avg_delay")
                )
            ) \
            .setMaxIter(2) \
            .run()

        # New state of the airports active in this window
        new_vertices = result.select(
            "id",
            F.col("new_prop_delayed").alias("prop_delayed"),
            F.col("new_delay_state").alias("delay_load")
        ).persist(StorageLevel.MEMORY_AND_DISK)

        # Merge into the global state: take the new value where there is one,
        # otherwise keep the previous value.
        vertices = vertices.join(
            new_vertices,
            "id",
            "left_outer"
        ).select(
            "id",
            F.coalesce(new_vertices["prop_delayed"], vertices["prop_delayed"]).alias("prop_delayed"),
            F.coalesce(new_vertices["delay_load"], vertices["delay_load"]).alias("delay_load")
        ).persist(StorageLevel.MEMORY_AND_DISK)

        # Every 10 windows, round-trip the state through Parquet to cut lineage
        if window_id % 10 == 0:
            checkpoint_path = f"dbfs:/checkpoints/vertices_{window_id}"
            vertices.write.mode("overwrite").parquet(checkpoint_path)
            vertices = spark.read.parquet(checkpoint_path).persist(StorageLevel.MEMORY_AND_DISK)

        # Per-window output, coalesced to reduce the number of files. Kept under
        # its own name so the persisted `window_vertices` can still be released.
        window_output = vertices.withColumn('window_start', F.lit(window[0])) \
                                .withColumn('window_end', F.lit(window[1])) \
                                .coalesce(4)

        window_output.write.mode("append").parquet(
            "dbfs:/student-groups/Group_4_1/interim/modeling_checkpoints/pregel.parquet"
        )

        # Release the per-window caches
        window_vertices.unpersist()
        new_vertices.unpersist()

    return vertices

# Run the windowed propagation over train0. The returned DataFrame is not
# assigned to a name; as the cell's last expression it is only echoed.
optimize_graph_processing(train0)

In [0]:
# Team root folder on DBFS; Spark checkpoints go to its "interim" sub-folder.
team_BASE_DIR = "dbfs:/student-groups/Group_4_1"

spark.sparkContext.setCheckpointDir(team_BASE_DIR + "/interim")

In [0]:
# Scratch check: builds the union of the stored states with the current window's
# state. The result is not assigned, so vertices_store itself is unchanged.
# Relies on `vertices`, `vertices_store` and `window` left over from an earlier cell.
vertices_store.union(vertices.withColumn('window_start', F.lit(window.start)) \
                         .withColumn('window_end', F.lit(window.end)))

In [0]:
# Scratch check: current vertex state labelled with the bounds of the leftover
# `window` variable. The result is not assigned to anything.
vertices.withColumn('window_start', F.lit(window.start)).withColumn('window_end', F.lit(window.end))

# graveyard

In [0]:
# Inspect train0 ordered by sched_depart_utc (ascending).
display(train0.orderBy('sched_depart_utc'))

In [0]:
# Show the edges that fall inside the leftover `window` (start inclusive, end
# exclusive), newest first.
print(f'window: {window}')
display(edges_with_windows.filter(
        (F.col("timestamp") >= F.lit(window.start)) & 
        (F.col("timestamp") < F.lit(window.end))
    ).orderBy(F.desc("timestamp")))

In [0]:
# Print every collected window row, one per line.
for w in windows:
    print(w)

In [0]:
# Initialize vertices with non-zero values
vertices = (
    train0.select(F.col("ORIGIN").alias("id"))
    .union(train0.select(F.col("DEST").alias("id")))
    .distinct()
    .withColumn("delay_load", F.lit(0.001))
    .withColumn("prop_delayed", F.lit(0.001))
)

# Define edges limited to a single day and 10 records.
# Here an edge runs from the prior flight's origin to this flight's origin.
edges = train0.filter(F.col("sched_depart_utc").between("2015-01-01", "2015-01-01 23:59:59")) \
    .select(
        F.col("priorflight_origin").alias("src"),
        F.col("ORIGIN").alias("dst"),
        F.col("prop_delayed").alias("edge_prop_delayed"),
        F.col("sched_depart_utc").alias("timestamp"),
        F.col("priorflight_depdelay_calc").alias("delay_load")
    ).limit(10)

print(f'Edges:')
edges.show()

edges.cache()


# Use 24-hour windows to group all test edges
edges_with_windows = edges.withColumn(
    "window",
    F.window("timestamp", "24 hours")
)

edges_with_windows.cache()
print(f'Edges with windows: ')
edges_with_windows.show()
# Collect the distinct windows
windows = edges_with_windows.select("window.start", "window.end").distinct().collect()

for window in windows:
    print(f"Processing window {window}")

    # Compare against the bounds of the window being processed. Comparing with
    # the row's own `window` column is always true and filters nothing.
    current_edges = edges_with_windows.filter(
        (F.col("timestamp") >= F.lit(window.start)) &
        (F.col("timestamp") < F.lit(window.end))
    )

    print(f'Current edges:')
    current_edges.show()

    # Build graph and run Pregel (same as before)
    # ...

    print("\nFinal vertex states:")
    vertices.orderBy(F.desc("delay_load")).show(10, truncate=False)


In [0]:
# Side-by-side look at the three edge frames left over from the previous cell.
display(edges)
display(current_edges)
display(edges_with_windows)

In [0]:
# Up to 10 current edges whose destination is PPG.
display(current_edges.filter(F.col("dst") == "PPG").limit(10))

In [0]:
# Flights in train0 whose priorflight_dest is PPG.
display(train0.filter(F.col('priorflight_dest')=='PPG'))

In [0]:
# Experiment on three sample windows (indices 10-12).
# NOTE(review): relies on `windows`, `vertices` and `edges_with_windows` from earlier
# cells, and on a `lookback_start` column that no visible cell creates - confirm.
for window in windows[10:13]:
    print(f"Processing window {window}")
    
    # 1. Filter edges in the current window
    # NOTE(review): both bounds are columns of the row itself, not values taken from
    # the loop's `window`, so the same edges are selected on every iteration - confirm.
    current_edges = edges_with_windows.filter(
        (F.col("timestamp") >= F.col("lookback_start")) & 
        (F.col("timestamp") < F.col("window.end"))
    )

    # 2. Precompute source's average edge_prop_delayed
    src_prop = current_edges.groupBy("src").agg(
        F.avg("edge_prop_delayed").alias("_tmp_src_prop")
    )
    
    # 3. Join with vertices and fill nulls
    # Left join keeps every vertex; vertices with no outgoing edge get 0.0.
    vertices_window = vertices.join(
        src_prop,
        vertices.id == src_prop.src,
        "left"
    ).fillna(0.0, subset=["_tmp_src_prop"])

    # 4. Build graph
    g = GraphFrame(vertices_window, current_edges)

    # 5. Run Pregel with structured messages
    result = g.pregel \
        .withVertexColumn(
            "new_delay_state",
            F.col("delay_load"),
            (F.col("delay_load") * 0.2) +  # keep 20% of the previous value (80% decay)
            F.coalesce(Pregel.msg().getItem("sum_delay"), F.lit(0.0))
        ) \
        .withVertexColumn(
            "_tmp_prop",
            F.col("_tmp_src_prop"),
            (F.col("_tmp_prop") * 0.9) +  # keep 90% of the previous value (10% decay)
            F.coalesce(Pregel.msg().getItem("avg_prop"), F.lit(0.0)) * 0.1
        ) \
        .sendMsgToDst(
            F.struct(
                (Pregel.edge("edge_prop_delayed") * 0.85).alias("prop"),  # edge prop scaled by a constant 0.85
                (Pregel.edge("delay_load") * Pregel.src("_tmp_prop")).alias("delay")  # edge delay weighted by the source's _tmp_prop
            )
        ) \
        .aggMsgs(
            F.struct(
                F.avg(Pregel.msg().getItem("prop")).alias("avg_prop"),
                F.avg(Pregel.msg().getItem("delay")).alias("sum_delay")  # despite the alias, this is an average
            )
        ) \
        .setMaxIter(3) \
        .run()
    
    # 7. Show results
    # `vertices` is not updated here, so every window starts from the same state.
    print("\nFinal vertex states:")
    result.orderBy(F.desc('new_delay_state')).show(10, truncate=False)

    # TODO: checkpoint the vertices DataFrame (not implemented in this cell)


In [0]:
from pyspark.sql import functions as F
from graphframes import GraphFrame
from graphframes.lib import Pregel

# Initialize vertices: one per airport code seen as ORIGIN or DEST
vertices = (
    train0.select(F.col("ORIGIN").alias("id"))
    .union(train0.select(F.col("DEST").alias("id")))
    .distinct()
    .withColumn("delay_load", F.lit(15)) #initialization super important - maybe better to use a more informed metric
    .withColumn("prop_delayed", F.lit(0.2))
)

# Define edges: basically the t-1 flight for each record
# (prior flight's origin -> this flight's origin)
edges = train0.select(
    F.col("priorflight_origin").alias("src"),
    F.col("ORIGIN").alias("dst"),
    F.col("prop_delayed").alias("edge_prop_delayed"),
    F.col("sched_depart_utc").alias("timestamp"),
    F.col("priorflight_depdelay_calc").alias("delay_load")
)

# Non-overlapping (tumbling) windows
edges_with_windows = edges.withColumn(
    "window",
    F.window("timestamp", "4 hour")  # Non-overlapping by default
)

# Collect distinct windows (no lookback_start needed)
windows = edges_with_windows.select("window.start", "window.end") \
    .distinct().orderBy("window.start").collect()

for window in windows[1:13]:  # sample subset of windows
    print(f"Processing window {window}")

    # 1. Filter edges strictly within the current window. The bounds come from
    #    the loop's `window` row; comparing against the row's own `window`
    #    column is always true and would feed every edge into every iteration.
    current_edges = edges_with_windows.filter(
        (F.col("timestamp") >= F.lit(window.start)) &
        (F.col("timestamp") < F.lit(window.end))
    )

    # 2. Build graph using previous vertex state carried over
    g = GraphFrame(vertices, current_edges)

    # 3. Run Pregel: each vertex column is (name, initial value, update expression)
    result = g.pregel \
        .withVertexColumn(
            "new_delay_state",
            F.col("delay_load"),
            F.col("delay_load")*.5 +
            F.coalesce(Pregel.msg().getItem("avg_delay"), F.lit(0.0))
        ) \
        .withVertexColumn(
            "new_prop_delayed",
            F.col("prop_delayed"),
            F.coalesce(Pregel.msg().getItem("avg_prop"), F.lit(0.0))
        ) \
        .sendMsgToDst(
            F.struct(
                Pregel.edge("edge_prop_delayed").alias("prop"),
                (Pregel.edge("delay_load")*Pregel.src("new_prop_delayed")).alias("delay_scaled")
            )
        ) \
        .aggMsgs(
            F.struct(
                F.avg(Pregel.msg().getItem("prop")).alias("avg_prop"),  # Average edge_prop_delayed over incoming edges
                F.avg(Pregel.msg().getItem("delay_scaled")).alias("avg_delay")
            )
        ) \
        .setMaxIter(5) \
        .run()

    # 4. Update vertices for next window
    vertices = result.select(
        "id",
        F.col("new_prop_delayed").alias("prop_delayed"),
        F.col("new_delay_state").alias("delay_load")
    ).cache()

    # 5. Force materialization
    vertices.count()

    # 6. Show results
    print("\nFinal vertex states:")
    vertices.filter(F.col('id').isin('JFK', 'BOS', 'ORD','LAX','SFO','MSP')).show(truncate=False)


In [0]:
from pyspark.sql import functions as F
from graphframes import GraphFrame
from graphframes.lib import Pregel

# Initialize vertices: one per airport code seen as ORIGIN or DEST
vertices = (
    train0.select(F.col("ORIGIN").alias("id"))
    .union(train0.select(F.col("DEST").alias("id")))
    .distinct()
    .withColumn("delay_load", F.lit(0)) #initialization super important - maybe better to use a more informed metric
    .withColumn("prop_delayed", F.lit(0.2))
)

# Edges follow the flight itself: ORIGIN -> DEST, carrying its departure delay
edges = train0.select(
    F.col("ORIGIN").alias("src"),
    F.col("DEST").alias("dst"),
    F.col("prop_delayed").alias("edge_prop_delayed"),
    F.col("sched_depart_utc").alias("timestamp"),
    F.col("DEP_DELAY").alias("delay_load")
)

# Non-overlapping (tumbling) windows
edges_with_windows = edges.withColumn(
    "window",
    F.window("timestamp", "4 hour")  # Non-overlapping by default
)

# Collect distinct windows (no lookback_start needed)
windows = edges_with_windows.select("window.start", "window.end") \
    .distinct().orderBy("window.start").collect()

for window in windows[1:13]:  # sample subset of windows
    print(f"Processing window {window}")

    # 1. Filter edges strictly within the current window. The bounds come from
    #    the loop's `window` row; comparing against the row's own `window`
    #    column is always true and would feed every edge into every iteration.
    current_edges = edges_with_windows.filter(
        (F.col("timestamp") >= F.lit(window.start)) &
        (F.col("timestamp") < F.lit(window.end))
    )

    # 2. Build graph using previous vertex state carried over
    g = GraphFrame(vertices, current_edges)

    # 3. Run Pregel: each vertex column is (name, initial value, update expression)
    result = g.pregel \
        .withVertexColumn(
            "new_delay_state",
            F.col("delay_load"),
            F.col("delay_load")*.4 +
            F.coalesce(Pregel.msg().getItem("avg_delay"), F.lit(0.0))
        ) \
        .withVertexColumn(
            "new_prop_delayed",
            F.col("prop_delayed"),
            F.coalesce(Pregel.msg().getItem("avg_prop"), F.lit(0.0))
        ) \
        .sendMsgToDst(
            F.struct(
                Pregel.edge("edge_prop_delayed").alias("prop"),
                (Pregel.edge("delay_load")*Pregel.src("new_prop_delayed")).alias("delay_scaled")
            )
        ) \
        .aggMsgs(
            F.struct(
                F.avg(Pregel.msg().getItem("prop")).alias("avg_prop"),  # Average edge_prop_delayed over incoming edges
                F.avg(Pregel.msg().getItem("delay_scaled")).alias("avg_delay")
            )
        ) \
        .setMaxIter(3) \
        .run()

    # 4. Update vertices for next window
    vertices = result.select(
        "id",
        F.col("new_prop_delayed").alias("prop_delayed"),
        F.col("new_delay_state").alias("delay_load")
    ).cache()

    # 5. Force materialization
    vertices.count()

    # 6. Show results
    print("\nFinal vertex states:")
    vertices.filter(F.col('id').isin('JFK', 'BOS', 'ORD','LAX','SFO','MSP')).show(truncate=False)


EMA version

Come back to this when I have actual edge-specific risks.

In [0]:
from pyspark.sql import functions as F
from graphframes import GraphFrame
from graphframes.lib import Pregel

# Initialize vertices with historical_prop
vertices = (
    train0.select(F.col("ORIGIN").alias("id"))
    .union(train0.select(F.col("DEST").alias("id")))
    .distinct()
    .withColumn("delay_load", F.lit(0))
    .withColumn("prop_delayed", F.lit(0.2))
    .withColumn("historical_prop", F.lit(0.2))  # Initial historical proportion
)

# Define edges with raw delay minutes
# (prior flight's origin -> this flight's origin)
edges = train0.select(
    F.col("priorflight_origin").alias("src"),
    F.col("ORIGIN").alias("dst"),
    F.col("prop_delayed").alias("edge_prop_delayed"),
    F.col("sched_depart_utc").alias("timestamp"),
    F.col("priorflight_depdelay_calc").alias("delay_load")
)

# Add non-overlapping 4-hour windows
edges_with_windows = edges.withColumn(
    "window",
    F.window("timestamp", "4 hours")  # Adjust window size as needed
)

# Collect distinct windows
windows = edges_with_windows.select("window.start", "window.end") \
    .distinct().orderBy("window.start").collect()

for window in windows[100:113]:  # Sample: windows at indices 100-112
    print(f"Processing window {window}")

    # 1. Filter edges in the current window. The bounds come from the loop's
    #    `window` row; comparing against the row's own `window` column is
    #    always true and would feed every edge into every iteration.
    current_edges = edges_with_windows.filter(
        (F.col("timestamp") >= F.lit(window.start)) &
        (F.col("timestamp") < F.lit(window.end))
    )

    # 2. Compute current window's average edge_prop_delayed per source
    window_src_prop = current_edges.groupBy("src").agg(
        F.avg("edge_prop_delayed").alias("current_prop")
    )

    # 3. Join with vertices to update historical_prop
    #    (EMA: 80% current window, 20% historical). Vertices with no outgoing
    #    edge in this window fall back to their historical value.
    vertices_window = vertices.join(
            window_src_prop,
            vertices.id == window_src_prop.src,
            "left"
        ).withColumn('current_prop',
                     F.coalesce(F.col("current_prop"),  F.col("historical_prop"))
        ).withColumn(
            "new_historical_prop",
            F.col("historical_prop") * 0.2 + F.col("current_prop") * 0.8  #somewhat influenced by historical but not a lot
        )

    # 4. Build graph with updated historical_prop
    g = GraphFrame(vertices_window, current_edges)

    # 5. Run Pregel using historical_prop in messages;
    #    each vertex column is (name, initial value, update expression)
    result = g.pregel \
    .withVertexColumn(
        "new_delay_state",
        F.col("delay_load"),
        F.col("delay_load") * .15 + F.coalesce(
            Pregel.msg().getItem("avg_delay")*Pregel.msg().getItem("avg_prop"),
            F.lit(0.0))
    ) \
    .withVertexColumn(
            "newhprop",
            F.col("new_historical_prop"), # Start from the EMA value
            F.coalesce(Pregel.msg().getItem("avg_prop"), F.lit(0.2))
            ) \
    .sendMsgToDst(
        F.struct(
            Pregel.src("newhprop").alias("prop"),
            (Pregel.edge("delay_load")).alias("delay_scaled")  # raw edge delay; no scaling despite the alias
        )
    ) \
    .aggMsgs(
        F.struct(
            F.avg(Pregel.msg().getItem("prop")).alias("avg_prop"),
            F.avg(Pregel.msg().getItem("delay_scaled")).alias("avg_delay")
        )
    ) \
    .setMaxIter(3) \
    .run()

    # 6. Update vertices for next window
    vertices = result.select(
        "id",
        F.col("newhprop").alias("historical_prop"),  # Carry forward
        F.col("new_delay_state").alias("delay_load")
    ).cache()

    # 7. Force materialization
    vertices.count()

    # 8. Show results
    print("\nFinal vertex states:")
    vertices.filter(F.col('id').isin('JFK', 'BOS', 'ORD','LAX','SFO','DFW')).show(truncate=False)


In [0]:
from pyspark.sql import functions as F
from graphframes import GraphFrame
from graphframes.lib import Pregel

# Every airport seen as ORIGIN or DEST becomes a vertex, starting with zero
# delay load and a delayed-proportion of 0.2.
vertices = (
    train0.select(F.col("ORIGIN").alias("id"))
    .union(train0.select(F.col("DEST").alias("id")))
    .distinct()
    .withColumn("delay_load", F.lit(0))
    .withColumn("prop_delayed", F.lit(0.2))
    .withColumn("historical_prop", F.lit(0.2))  # Initial historical proportion
)

# One edge per flight row: prior flight's origin -> this flight's origin,
# carrying the prior flight's departure delay as the edge load.
edges = train0.select(
    F.col("priorflight_origin").alias("src"),
    F.col("ORIGIN").alias("dst"),
    F.col("prop_delayed").alias("edge_prop_delayed"),
    F.col("sched_depart_utc").alias("timestamp"),
    F.col("priorflight_depdelay_calc").alias("delay_load")
)

# Tag each edge with its non-overlapping 4-hour tumbling window
edges_with_windows = edges.withColumn(
    "window",
    F.window("timestamp", "4 hours")  # Adjust window size as needed
)

# Collect the distinct window bounds to the driver, in chronological order
windows = edges_with_windows.select("window.start", "window.end") \
    .distinct().orderBy("window.start").collect()

# NOTE: the loop variable is a collected Row with fields "start" and "end".
# Its name shadows the `window` function imported from pyspark.sql.functions
# at the top of the notebook; the name is kept because later cells may read it.
for window in windows[100:113]:  # positions 100-112 of the ordered window list (13 windows)
    print(f"Processing window {window}")

    # 1. Keep only the edges that fall inside THIS window. The bounds must come
    #    from the collected Row: comparing "timestamp" with the row's own
    #    "window" struct column is always true and would select every edge.
    current_edges = edges_with_windows.filter(
        (F.col("timestamp") >= F.lit(window["start"])) &
        (F.col("timestamp") < F.lit(window["end"]))
    )

    # 2. Compute current window's average edge_prop_delayed per source
    window_src_prop = current_edges.groupBy("src").agg(
        F.avg("edge_prop_delayed").alias("current_prop")
    )

    # 3. Blend the window value into historical_prop (20% historical, 80% current).
    #    Vertices with no outgoing edges in this window fall back to their
    #    historical value, so their blended value is unchanged.
    vertices_window = vertices.join(
            window_src_prop,
            vertices.id == window_src_prop.src,
            "left"
        ).withColumn('current_prop',
                     F.coalesce(F.col("current_prop"),  F.col("historical_prop"))
        ).withColumn(
            "new_historical_prop",
            F.col("historical_prop") * 0.2 + F.col("current_prop") * 0.8  # 20% historical, 80% current
        )

    # 4. Build graph with updated historical_prop
    g = GraphFrame(vertices_window, current_edges)

    # 5. Run Pregel: each vertex receives the average prop and the average
    #    (8x scaled) delay of its incoming edges.
    result = g.pregel \
    .withVertexColumn(
        "new_delay_state",
        F.col("delay_load"),
        (F.col("delay_load") * F.col('new_historical_prop')) +  # previous load scaled by blended prop + incoming average
        F.coalesce(Pregel.msg().getItem("avg_delay"), F.lit(0.0))
    ) \
    .withVertexColumn(
            "newhprop",
            F.col("new_historical_prop"),  # starts at the blended value
            F.coalesce(Pregel.msg().getItem("avg_prop"), F.lit(0.2))  # 0.2 when no message arrives
            ) \
    .sendMsgToDst(
        F.struct(
            Pregel.src("newhprop").alias("prop"),
            (Pregel.edge("delay_load") * 8).alias("delay_scaled")
        )
    ) \
    .aggMsgs(
        F.struct(
            F.avg(Pregel.msg().getItem("prop")).alias("avg_prop"),
            F.avg(Pregel.msg().getItem("delay_scaled")).alias("avg_delay")
        )
    ) \
    .setMaxIter(2) \
    .run()

    # 6. Update vertices for next window
    vertices = result.select(
        "id",
        F.col("newhprop").alias("historical_prop"),  # Carry forward
        F.col("new_delay_state").alias("delay_load")
    ).cache()

    # 7. Force materialization
    vertices.count()

    # 8. Show results
    print("\nFinal vertex states:")
    vertices.filter(F.col('id').isin('JFK', 'BOS', 'ORD','LAX','SFO','DFW')).show(truncate=False)


In [0]:


# Every airport seen as ORIGIN or DEST becomes a vertex, starting with zero
# delay load and a delayed-proportion of 0.2.
vertices = (
    train0.select(F.col("ORIGIN").alias("id"))
    .union(train0.select(F.col("DEST").alias("id")))
    .distinct()
    .withColumn("delay_load", F.lit(0.0))
    .withColumn("prop_delayed", F.lit(0.2))
)

# One edge per flight row (ORIGIN -> DEST), timestamped by actual arrival and
# carrying the departure delay as a double.
edges = train0.select(
    F.col("ORIGIN").alias("src"),
    F.col("DEST").alias("dst"),
    F.col("actual_arr_utc").alias("timestamp"),
    # cast first, then alias, so the resulting column name is explicit
    F.col("DEP_DELAY").cast('double').alias("delay_load")
)

# Tag each edge with its non-overlapping 4-hour tumbling window
edges_with_windows = edges.withColumn(
    "window",
    F.window("timestamp", "4 hours")  # Adjust window size as needed
)

# Collect the distinct window bounds to the driver, in chronological order
windows = edges_with_windows.select("window.start", "window.end") \
    .distinct().orderBy("window.start").collect()


# NOTE: the loop variable is a collected Row with fields "start" and "end".
# Its name shadows the `window` function imported from pyspark.sql.functions
# at the top of the notebook; the name is kept because later cells may read it.
for window in windows[100:113]:  # positions 100-112 of the ordered window list (13 windows)
    print(f"Processing window {window}")

    # Keep only the edges that fall inside THIS window. The bounds must come
    # from the collected Row: comparing "timestamp" with the row's own
    # "window" struct column is always true and would select every edge.
    current_edges = edges_with_windows.filter(
        (F.col("timestamp") >= F.lit(window["start"])) &
        (F.col("timestamp") < F.lit(window["end"]))
    )

    g = GraphFrame(vertices, current_edges)

    result = g.pregel \
    .withVertexColumn(
        "new_delay_state",
        F.col("delay_load"),  # initial
        F.col("delay_load") * .2 + F.coalesce(Pregel.msg().getItem("avg_delay"), F.lit(0.0))  # update
    ) \
    .withVertexColumn(
        "new_prop_delayed",
        F.col("prop_delayed"),  # initial
        F.coalesce(Pregel.msg().getItem("avg_prop"), F.lit(0.2))  # update; 0.2 when no message arrives
    ) \
    .sendMsgToDst(
        F.struct(
            # NOTE(review): "prop" is sent but never read by aggMsgs below, which
            # derives avg_prop from delay_scaled instead - confirm it is still wanted.
            (Pregel.src("prop_delayed") * .8).alias("prop"),
            (Pregel.edge("delay_load")).alias("delay_scaled")
        )
    ) \
    .aggMsgs(
        F.struct(
            # share of incoming edges whose delay exceeds 15
            F.avg(F.when(Pregel.msg().getItem('delay_scaled') > 15, 1).otherwise(0)
                  ).alias("avg_prop"),
            F.avg(Pregel.msg().getItem("delay_scaled")).alias("avg_delay")
        )
    ) \
    .setMaxIter(2) \
    .run()

    # Update vertices for next window
    vertices = result.select(
        "id",
        F.col("new_delay_state").alias("delay_load"),
        F.col('new_prop_delayed').alias('prop_delayed')
    ).cache()

    # Force materialization
    vertices.count()

    # Show results
    print("\nFinal vertex states:")
    vertices.filter(F.col('id').isin('JFK', 'BOS', 'ORD','LAX','SFO','DFW')).show(truncate=False)


In [0]:
# Inspect the edge DataFrame left in `current_edges` by the last loop iteration.
display(current_edges)

In [0]:
# Inspect the vertex DataFrame carried out of the last processed window.
display(vertices)

In [0]:
# One-off Pregel run against whatever graph is currently bound to `g`.
# NOTE(review): the messages read an edge column "window_prop_delayed"; confirm
# that the graph in `g` actually carries it before running this cell.
result = (
    g.pregel
    # Start every vertex at 0.0, then set it to 20% of its stored delay_load
    # plus the average scaled delay received (0.0 when no message arrives).
    .withVertexColumn(
        "new_delay_state",
        F.lit(0.0),
        (F.col("delay_load").cast("double") * 0.2)
        + F.coalesce(Pregel.msg().getItem("avg_delay"), F.lit(0.0)),
    )
    # Each edge sends its window proportion and its delay scaled by that proportion.
    .sendMsgToDst(
        F.struct(
            Pregel.edge("window_prop_delayed").alias("prop"),
            (
                Pregel.edge("delay_load").cast("double")
                * Pregel.edge("window_prop_delayed")
            ).alias("delay_scaled"),
        )
    )
    # Destination vertices average both message fields.
    .aggMsgs(
        F.struct(
            F.avg(Pregel.msg().getItem("prop")).alias("avg_prop"),
            F.avg(Pregel.msg().getItem("delay_scaled")).alias("avg_delay"),
        )
    )
    .setMaxIter(2)
    .run()
)


In [0]:


# Every airport seen as ORIGIN or DEST becomes a vertex, starting with zero
# delay load and a delayed-proportion of 0.2.
vertices = (
    train0.select(F.col("ORIGIN").alias("id"))
    .union(train0.select(F.col("DEST").alias("id")))
    .distinct()
    .withColumn("delay_load", F.lit(0))
    .withColumn("prop_delayed", F.lit(0.2))
    .withColumn("historical_prop", F.lit(0.2))  # Initial historical proportion
)

# One edge per flight row: prior flight's origin -> this flight's origin,
# carrying the prior flight's departure delay as the edge load.
edges = train0.select(
    F.col("priorflight_origin").alias("src"),
    F.col("ORIGIN").alias("dst"),
    F.col("prop_delayed").alias("edge_prop_delayed"),
    F.col("sched_depart_utc").alias("timestamp"),
    F.col("priorflight_depdelay_calc").alias("delay_load")
)

# Tag each edge with its non-overlapping 4-hour tumbling window
edges_with_windows = edges.withColumn(
    "window",
    F.window("timestamp", "4 hours")  # Adjust window size as needed
)

# Collect the distinct window bounds to the driver, in chronological order
windows = edges_with_windows.select("window.start", "window.end") \
    .distinct().orderBy("window.start").collect()

# NOTE: the loop variable is a collected Row with fields "start" and "end".
# Its name shadows the `window` function imported from pyspark.sql.functions
# at the top of the notebook; the name is kept because later cells may read it.
for window in windows[100:113]:  # positions 100-112 of the ordered window list (13 windows)
    print(f"Processing window {window}")

    # 1. Keep only the edges that fall inside THIS window. The bounds must come
    #    from the collected Row: comparing "timestamp" with the row's own
    #    "window" struct column is always true and would select every edge.
    current_edges = edges_with_windows.filter(
        (F.col("timestamp") >= F.lit(window["start"])) &
        (F.col("timestamp") < F.lit(window["end"]))
    )

    # 2. Compute current window's average edge_prop_delayed per source
    window_src_prop = current_edges.groupBy("src").agg(
        F.avg("edge_prop_delayed").alias("current_prop")
    )

    # 3. Blend the window value into historical_prop (20% historical, 80% current).
    #    Vertices with no outgoing edges in this window fall back to their
    #    historical value, so their blended value is unchanged.
    vertices_window = vertices.join(
            window_src_prop,
            vertices.id == window_src_prop.src,
            "left"
        ).withColumn('current_prop',
                     F.coalesce(F.col("current_prop"),  F.col("historical_prop"))
        ).withColumn(
            "new_historical_prop",
            F.col("historical_prop") * 0.2 + F.col("current_prop") * 0.8  # 20% historical, 80% current
        )

    # 4. Build graph with updated historical_prop
    g = GraphFrame(vertices_window, current_edges)

    # 5. Run Pregel: each edge sends its delay scaled by the source's prop;
    #    destinations average what they receive.
    result = g.pregel \
    .withVertexColumn(
        "new_delay_state",
        F.col("delay_load"),
        (F.col("delay_load") * .2) +  # 20% of previous delay load + average scaled incoming delay
        F.coalesce(Pregel.msg().getItem("avg_delay"), F.lit(0.0))
    ) \
    .withVertexColumn(
            "newhprop",
            F.col("new_historical_prop"),  # starts at the blended value
            F.coalesce(Pregel.msg().getItem("avg_prop"), F.lit(0.2))  # 0.2 when no message arrives
            ) \
    .sendMsgToDst(
        F.struct(
            Pregel.src("newhprop").alias("prop"),
            (Pregel.edge("delay_load") * Pregel.src('newhprop')).alias("delay_scaled")
        )
    ) \
    .aggMsgs(
        F.struct(
            F.avg(Pregel.msg().getItem("prop")).alias("avg_prop"),
            F.avg(Pregel.msg().getItem("delay_scaled")).alias("avg_delay")
        )
    ) \
    .setMaxIter(2) \
    .run()

    # 6. Update vertices for next window
    vertices = result.select(
        "id",
        F.col("newhprop").alias("historical_prop"),  # Carry forward
        F.col("new_delay_state").alias("delay_load")
    ).cache()

    # 7. Force materialization
    vertices.count()

    # 8. Show results
    print("\nFinal vertex states:")
    vertices.filter(F.col('id').isin('JFK', 'BOS', 'ORD','LAX','SFO','DFW')).show(truncate=False)


In [0]:
from pyspark.sql import functions as F
from graphframes import GraphFrame
from graphframes.lib import Pregel

# Every airport seen as ORIGIN or DEST becomes a vertex, starting with zero
# delay load and a delayed-proportion of 0.2.
vertices = (
    train0.select(F.col("ORIGIN").alias("id"))
    .union(train0.select(F.col("DEST").alias("id")))
    .distinct()
    .withColumn("delay_load", F.lit(0))
    .withColumn("prop_delayed", F.lit(0.2))
    .withColumn("historical_prop", F.lit(0.2))  # Initial historical proportion
)

# One edge per flight row (ORIGIN -> DEST) carrying its departure delay
edges = train0.select(
    F.col("ORIGIN").alias("src"),
    F.col("DEST").alias("dst"),
    F.col("prop_delayed").alias("edge_prop_delayed"),
    F.col("sched_depart_utc").alias("timestamp"),
    F.col("DEP_DELAY").alias("delay_load")
)

# Tag each edge with its non-overlapping 8-hour tumbling window
edges_with_windows = edges.withColumn(
    "window",
    F.window("timestamp", "8 hours")  # Adjust window size as needed
)

# Collect the distinct window bounds to the driver, in chronological order
windows = edges_with_windows.select("window.start", "window.end") \
    .distinct().orderBy("window.start").collect()

# NOTE: the loop variable is a collected Row with fields "start" and "end".
# Its name shadows the `window` function imported from pyspark.sql.functions
# at the top of the notebook; the name is kept because later cells may read it.
for window in windows[100:113]:  # positions 100-112 of the ordered window list (13 windows)
    print(f"Processing window {window}")

    # 1. Keep only the edges that fall inside THIS window. The bounds must come
    #    from the collected Row: comparing "timestamp" with the row's own
    #    "window" struct column is always true and would select every edge.
    current_edges = edges_with_windows.filter(
        (F.col("timestamp") >= F.lit(window["start"])) &
        (F.col("timestamp") < F.lit(window["end"]))
    )

    # 2. Compute current window's average edge_prop_delayed per source
    window_src_prop = current_edges.groupBy("src").agg(
        F.avg("edge_prop_delayed").alias("current_prop")
    )

    # 3. Attach the window's prop to each vertex; vertices with no outgoing
    #    edges in this window keep their historical_prop (no blending here).
    vertices_window = vertices.join(
            window_src_prop,
            vertices.id == window_src_prop.src,
            "left"
        ).withColumn('current_prop',
                     F.coalesce(F.col("current_prop"),  F.col("historical_prop"))
        )

    # 4. Build graph for this window
    g = GraphFrame(vertices_window, current_edges)

    # 5. Run Pregel: each edge sends half of its delay; destinations average
    #    what they receive.
    result = g.pregel \
    .withVertexColumn(
        "new_delay_state",
        F.col("delay_load"),
        F.col("delay_load") * .1 + F.coalesce(  # 10% of previous load + average incoming delay
            Pregel.msg().getItem("avg_delay"),
            F.lit(0.0))
    ) \
    .withVertexColumn(
            "newhprop",
            F.col("current_prop"),  # starts at this window's prop
            F.coalesce(Pregel.msg().getItem("avg_prop"), F.lit(0.2))  # 0.2 when no message arrives
            ) \
    .sendMsgToDst(
        F.struct(
            Pregel.src("newhprop").alias("prop"),
            (Pregel.edge("delay_load")*.5).alias("delay_scaled")
        )
    ) \
    .aggMsgs(
        F.struct(
            F.avg(Pregel.msg().getItem("prop")).alias("avg_prop"),
            F.avg(Pregel.msg().getItem("delay_scaled")).alias("avg_delay")
        )
    ) \
    .setMaxIter(3) \
    .run()

    # 6. Update vertices for next window
    vertices = result.select(
        "id",
        F.col("newhprop").alias("historical_prop"),  # Carry forward
        F.col("new_delay_state").alias("delay_load")
    ).cache()

    # 7. Force materialization
    vertices.count()

    # 8. Show results
    print("\nFinal vertex states:")
    vertices.filter(F.col('id').isin('JFK', 'BOS', 'ORD','LAX','SFO','DFW')).show(truncate=False)


In [0]:
# Show positive-delay edges inside the most recently processed window.
# The bounds come from the `window` Row left by the last loop iteration;
# comparing "timestamp" with the row's own "window" struct column is always
# true and would not restrict anything.
display(
    current_edges
    .filter(F.col("timestamp") >= F.lit(window["start"]))
    .filter(F.col("timestamp") < F.lit(window["end"]))
    .filter(F.col('delay_load') > 0)
)

In [0]:
# Rows of train0 with ORIGIN == 'JFK' and a non-null TAIL_NUM.
display(train0.filter(F.col('ORIGIN')=='JFK').filter(F.col('TAIL_NUM').isNotNull()))

In [0]:
# Rows of train0 where priorflight_depdelay_calc is negative.
display(train0.filter(F.col('priorflight_depdelay_calc')<0))

## Personalized PageRank (scratch work)

In [0]:

from pyspark.sql.window import Window
from pyspark.sql.types import *

# Vertex set: every airport id seen as ORIGIN or DEST, sorted by id.
v = (
    train0.select(F.col("ORIGIN").alias("id"))
    .union(train0.select(F.col("DEST").alias("id")))
    .distinct()
    .orderBy("id")
)

# Directed edge set: one ORIGIN -> DEST edge per row of train0.
e = train0.select(
    F.col("ORIGIN").alias("src"),
    F.col("DEST").alias("dst"),
)

# Graph over those airports and flights.
g = GraphFrame(v, e)

# Source ids sorted by id, each with a zero-based position in that ordering.
sources_df = (
    v.select("id")
    .distinct()
    .orderBy("id")
    .withColumn("index", F.row_number().over(Window.orderBy("id")) - 1)
)

# Plain Python list of the source ids, collected to the driver.
sources_flat = [row["id"] for row in sources_df.select("id").collect()]

# Personalized PageRank with every airport used as a source.
pageranked = g.parallelPersonalizedPageRank(
    resetProbability=0.15,
    sourceIds=sources_flat,
    maxIter=10,
)

# Ship the id list to executors once so UDFs can map vector positions to ids.
broadcast_sources = sc.broadcast(sources_flat)

# Return type for the top-sources UDF: an array of (origin, score) structs.
result_schema = ArrayType(
    StructType(
        [
            StructField("origin", StringType()),
            StructField("score", DoubleType()),
        ]
    )
)
def vector_to_dict(vector):
    """Return the ten highest-scoring sources of one PageRank vector.

    Args:
        vector: Iterable of per-source scores. Position ``i`` is paired with
            entry ``i`` of the broadcast source-id list.

    Returns:
        Up to ten ``(source_id, score)`` tuples ordered from highest to lowest
        score, or ``None`` when ``vector`` is ``None``.
    """
    if vector is None:
        return None

    # Read the ids from the broadcast value rather than the driver-side list,
    # so the list is not also pickled into the UDF closure.
    sources = broadcast_sources.value

    # Pair each position with its score, sort descending, keep the top 10
    top_entries = sorted(
        ((i, float(score)) for i, score in enumerate(vector)),
        key=lambda x: x[1],
        reverse=True
    )[:10]

    # Map positions to actual source IDs
    return [(sources[i], score) for i, score in top_entries]

# Wrap the helper as a Spark UDF that returns `result_schema`.
vector_to_dict_udf = F.udf(vector_to_dict, returnType=result_schema)

# Add the top-sources list for every vertex as column "pagerank_dict".
results = pageranked.vertices.withColumn(
    "pagerank_dict",
    vector_to_dict_udf(F.col("pageranks")),
)


In [0]:
# Total edge weight per (src, dst) pair; the aggregate column is auto-named "sum(weight)".
train0_edges = (
    pageranked.edges
    .groupBy("src", "dst")
    .agg(F.sum("weight"))
)

In [0]:
# Row of `results` for vertex JFK.
display(results.filter(F.col('id')=='JFK'))

In [0]:
# Row of `results` for vertex HYA.
display(results.filter(F.col('id')=='HYA'))

In [0]:
# Row of `results` for vertex LAX.
display(results.filter(F.col('id')=='LAX'))

In [0]:
# Summed edge weights per (src, dst) pair.
display(train0_edges)

In [0]:
# PageRank-result edges with dst == 'JFK', largest weight first.
display(pageranked.edges.filter(F.col('dst')=='JFK').orderBy(F.desc('weight')))

In [0]:
# PageRank-result edges with src == 'JFK', largest weight first.
display(pageranked.edges.filter(F.col('src')=='JFK').orderBy(F.desc('weight')))

In [0]:
# Summed weights of edges leaving JFK, largest first.
display(train0_edges.filter(F.col('src')=='JFK').orderBy(F.desc('sum(weight)')))

In [0]:
# Number of train0 rows per DEST for ORIGIN == 'JFK', most frequent first.
display(train0.filter(F.col('ORIGIN')=='JFK').groupBy('DEST').count().orderBy(F.col('count').desc()))

In [0]:
# Summed weights of edges arriving at JFK, largest first.
display(train0_edges.filter(F.col('dst')=='JFK').orderBy(F.desc('sum(weight)')))

In [0]:
# Share of rows with outcome == 1 for every (ORIGIN, DEST) pair.
edge_delay_weights = (
    train0
    .groupBy("ORIGIN", "DEST")
    .agg(
        (
            F.sum(F.when(F.col("outcome") == 1, 1).otherwise(0))
            / F.count("*")
        ).alias("delay_ratio")
    )
)

# Flip each route so edges point DEST -> ORIGIN and carry the ratio as "delay_weight".
# NOTE(review): presumably the PageRank call below does not read "delay_weight" - confirm.
reversed_edges = edge_delay_weights.selectExpr(
    "DEST AS src",
    "ORIGIN AS dst",
    "delay_ratio AS delay_weight",
)

# Graph over the same vertices with the flipped edges.
g_reversed = GraphFrame(v, reversed_edges)

# Airport whose row is inspected after the run.
target_airport = "JFK"

# Personalized PageRank on the reversed graph, using every id in
# sources_flat as a source (not only the target airport).
pageranked1 = g_reversed.parallelPersonalizedPageRank(
    resetProbability=0.15,
    sourceIds=sources_flat,
    maxIter=20,
)

# Extract top influencers for the target
def get_top_influencers(pagerank_vector, sources, top_n=10, exclude=None):
    """Return the ids of the highest-scoring sources in a PageRank vector.

    Args:
        pagerank_vector: Iterable of scores; position ``i`` belongs to ``sources[i]``.
        sources: Sequence of source ids aligned with ``pagerank_vector``.
        top_n: Maximum number of ids to return (default 10).
        exclude: Optional id to leave out of the result. When ``None`` the
            single highest-scoring entry is dropped by position, which treats
            the top score as the vertex's own self-score.

    Returns:
        List of at most ``top_n`` ids ordered from highest to lowest score, or
        ``None`` when ``pagerank_vector`` is ``None``.
    """
    if pagerank_vector is None:
        return None

    ranked = sorted(
        ((sources[i], float(score)) for i, score in enumerate(pagerank_vector)),
        key=lambda x: -x[1]
    )

    if exclude is None:
        # Positional exclusion: skip rank 0, assumed to be the self-score
        return [name for name, _ in ranked[1:top_n + 1]]

    # Id-based exclusion: correct even when the self-score is not ranked first
    return [name for name, _ in ranked if name != exclude][:top_n]

# UDF form of get_top_influencers with the driver-side id list bound in;
# it returns an array of id strings.
get_top_influencers_udf = F.udf(
    lambda scores: get_top_influencers(scores, sources_flat),
    ArrayType(StringType()),
)

# Apply the UDF to the target airport's vertex only and keep id + result.
results2 = (
    pageranked1.vertices
    .filter(F.col("id") == target_airport)
    .withColumn("top_influencers", get_top_influencers_udf("pageranks"))
    .select("id", "top_influencers")
)

In [0]:
# Row of `results2` for vertex JFK.
display(results2.filter(F.col('id')=='JFK'))

In [0]:
# Distinct `pageranked` edges with dst == 'JFK', largest weight first.
display(pageranked.edges.filter(F.col('dst')=='JFK').distinct().orderBy(F.desc('weight')))

In [0]:
# Distinct `pageranked` edges with src == 'JFK', largest weight first.
display(pageranked.edges.filter(F.col('src')=='JFK').distinct().orderBy(F.desc('weight')))

In [0]:
# Reversed-graph edges with dst == 'JFK': total weight per src.
display(pageranked1.edges.filter(F.col('dst')=='JFK').groupBy('src').agg(F.sum('weight')))

In [0]:
# `pageranked` edges with dst == 'JFK': total weight per src.
display(pageranked.edges.filter(F.col('dst')=='JFK').groupBy('src').agg(F.sum('weight')))

In [0]:
# Distinct reversed-graph edges with dst == 'JFK', largest weight first.
display(pageranked1.edges.filter(F.col('dst')=='JFK').distinct().orderBy(F.desc('weight')))

In [0]:
# Reversed-graph edges with src == 'JFK'.
display(pageranked1.edges.filter(F.col('src')=='JFK'))

In [0]:
# Vertex row for JFK from the reversed-graph PageRank result.
display(pageranked1.vertices.filter(F.col('id')=='JFK'))