# Imports

In [0]:

import pyspark.sql.functions as F
from pyspark.sql.functions import col,isnan, when, count, concat_ws, countDistinct, collect_set, rank, window, avg, hour, udf, isnan, pandas_udf, to_timestamp, lit, PandasUDFType
import matplotlib.pyplot as plt
import pandas as pd
import re
import pytz
from datetime import datetime, timedelta, time
import numpy as np
from pyspark.sql import types

from pyspark.sql.types import StructType, StructField, StringType, IntegerType, StructType, DoubleType, LongType

from pyspark.sql import Window
import seaborn as sns

In [0]:
!pip install python-geohash
import geohash
from geohash import bbox

In [0]:
# Load the joined checkpoint parquet that the weather-cleaning cells below start from.
df=spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat.parquet/")

In [0]:
# Team root folder on DBFS; Spark checkpoints are written under its `interim` sub-folder.
team_BASE_DIR = "dbfs:/student-groups/Group_4_1"
spark.sparkContext.setCheckpointDir(team_BASE_DIR + "/interim")

# Weather Cleaning


In [0]:


# Back-fill missing wind readings from the free-text remarks column (origin_REM).
# Wind group in the remarks: capture group 2 is the sustained speed, group 3 the
# optional gust speed that follows "G".
_wind_group_regex = r'\b(\d{3})(\d{2,3})(?:G(\d{2,3}))?KT\b'
# Peak-wind remark ("PK WND ..."): capture group 2 is the peak wind speed.
_peak_wind_regex = r'PK WND (\d{3})(\d{2,3})/(\d{4})'

# regexp_extract yields '' when nothing matches, which the int cast turns into null.
_sustained_from_remarks = F.regexp_extract(
    F.col("origin_REM"), _wind_group_regex, 2
).cast("int")
_gust_from_remarks = F.greatest(
    # Regular wind gust (G group)
    F.regexp_extract(F.col("origin_REM"), _wind_group_regex, 3).cast("int"),
    # Peak wind gust (PK WND group)
    F.regexp_extract(F.col("origin_REM"), _peak_wind_regex, 2).cast("int"),
)

df_interpolate = df.withColumn(
    "origin_HourlyWindSpeed",
    # Only touch rows whose reported value is missing.
    F.when(
        F.col("origin_HourlyWindSpeed").isNull(), _sustained_from_remarks
    ).otherwise(F.col("origin_HourlyWindSpeed")),
).withColumn(
    "origin_HourlyWindGustSpeed",
    F.when(
        F.col("origin_HourlyWindGustSpeed").isNull(), _gust_from_remarks
    ).otherwise(F.col("origin_HourlyWindGustSpeed")),
)


# Clean origin_HourlyPrecipitation in four steps, composed into one expression:
#   1. null or '*' -> value recovered from the remarks column
#      (hundredths of inch kept in "remarks" section)
#   2. every 'T' is replaced by '0.01'
#   3. keep only the first numeric token
#   4. cast to double
_precip_from_remarks = (
    F.regexp_extract(F.col("origin_REM"), r" P(\d+)", 1).cast("int") * 0.01
)
_precip_filled = F.when(
    (F.col("origin_HourlyPrecipitation").isNull()) | (F.col("origin_HourlyPrecipitation") == '*'),
    _precip_from_remarks,
).otherwise(F.col("origin_HourlyPrecipitation"))
_precip_trace_replaced = F.regexp_replace(_precip_filled, 'T', '0.01')
_precip_numeric_text = F.regexp_extract(_precip_trace_replaced, r"[0-9]+(\.[0-9]+)?", 0)  # Match digits

df_interpolate = df_interpolate.withColumn(
    'origin_HourlyPrecipitation',
    _precip_numeric_text.cast(DoubleType()),
)

In [0]:


def encode_geohash(precision: int):
    """Build a pandas UDF that geohashes (latitude, longitude) pairs.

    Args:
        precision: number of geohash characters to produce.

    Returns:
        A string-typed pandas UDF taking a latitude Series and a longitude
        Series and returning one geohash per row. Rows with a missing
        coordinate, or for which `geohash.encode` raises, yield null.
    """
    @pandas_udf("string")
    def encode(latitudes: pd.Series, longitudes: pd.Series) -> pd.Series:
        def safe_encode(lat, lon):
            # Missing coordinates reach pandas as None/NaN. Return null for them
            # explicitly instead of relying on geohash.encode to raise, so that
            # rows without a location can never be assigned a geohash.
            if pd.isna(lat) or pd.isna(lon):
                return None
            try:
                return geohash.encode(lat, lon, precision)
            except Exception:
                # Best effort: an un-encodable coordinate becomes null rather
                # than failing the whole batch.
                return None
        return latitudes.combine(longitudes, safe_encode)
    return encode

# Precision 2 keeps only the first two geohash characters, so each cell is coarse;
# the `geohash` column is the partition key used by coalesce_within_geohash.
geohash_udf = encode_geohash(precision=2)
df_interpolate = df_interpolate.withColumn('geohash', geohash_udf(F.col('origin_LATITUDE'), F.col('origin_LONGITUDE')))
display(df_interpolate)

In [0]:
def coalesce_within_geohash(
    df, 
    target_col, 
    geohash_col="geohash", 
    dt_col="sched_depart_utc", 
    window_size=6
):
    """Replace `target_col` with the latest non-null value seen in its geohash.

    Rows are partitioned by `geohash_col` and ordered by `dt_col` (cast to
    long). For each row the window spans the `window_size` preceding rows plus
    the row itself, so `window_size` counts rows, not hours. A row that already
    has a value keeps it; a null row takes the most recent non-null value in
    that window and stays null when there is none.
    """
    ordering_key = F.col(dt_col).cast("long")
    trailing_rows = (
        Window.partitionBy(geohash_col)
        .orderBy(ordering_key)
        .rowsBetween(-window_size, 0)
    )
    latest_known = F.last(target_col, ignorenulls=True).over(trailing_rows)
    return df.withColumn(target_col, latest_known)

In [0]:
# Cast every hourly weather measurement to float in a single withColumns call.
_float_weather_cols = [
    "origin_HourlyPrecipitation",
    "origin_HourlyWindGustSpeed",
    "origin_HourlyWindSpeed",
    "origin_HourlyDewPointTemperature",
    "origin_HourlyDryBulbTemperature",
    "origin_HourlyPressureChange",
    "origin_HourlyRelativeHumidity",
    "origin_HourlyWetBulbTemperature",
    "origin_HourlyVisibility",
]
df_interpolated = df_interpolate.withColumns(
    {name: df_interpolate[name].cast("float") for name in _float_weather_cols}
)

In [0]:
columns_to_fill = ['origin_HourlyVisibility','origin_HourlyWindSpeed','origin_HourlyDewPointTemperature','origin_HourlyDryBulbTemperature','origin_HourlyPressureChange','origin_HourlyRelativeHumidity','origin_HourlyWetBulbTemperature','origin_HourlyPrecipitation','origin_HourlyWindGustSpeed']

# Fill each weather column from its geohash neighbours, one column at a time.
# The loop variable is deliberately not named `col`: that would rebind the
# `col` function imported from pyspark.sql.functions to a plain string for the
# rest of the notebook session.
for fill_col in columns_to_fill:
    df_interpolated = coalesce_within_geohash(df_interpolated, fill_col)



In [0]:

# Count null/NaN values per weather column, before and after the geohash fill.
# The aggregate expressions are built once and reused on both frames.
_null_or_nan_counts = [
    F.count(F.when(F.col(c).isNull() | F.isnan(c), c)).alias(c)
    for c in columns_to_fill
]
_has_origin_location = F.col('origin_LATITUDE').isNotNull()

null_counts = df_interpolated.filter(_has_origin_location).select(_null_or_nan_counts)
null_counts_orig = df_interpolate.filter(_has_origin_location).select(_null_or_nan_counts)

# null counts where we have non null loc (obviously, null loc is impossible to fill in by geohash)
# First row: before the geohash fill; second row: after it.
display(null_counts_orig.unionByName(null_counts))

In [0]:


@pandas_udf(DoubleType())
def exponential_smoothing_pandas(values: pd.Series) -> pd.Series:
    """Vectorized UDF returning an exponentially weighted mean (alpha=0.5).

    `smooth_column_optimized` passes the `collect_list` window column, so each
    element of `values` is a sequence of observations, not a scalar. The
    previous implementation treated that non-numeric Series as an edge case and
    returned 0.0 for every row, so remaining nulls were filled with 0.0 instead
    of a smoothed estimate.

    Args:
        values: one batch of the input column; either a Series whose elements
            are sequences (array column) or a numeric Series (scalar column).

    Returns:
        A float64 Series aligned with `values`. For sequence input each entry
        is the last value of the exponentially weighted mean over that row's
        sequence, or missing when the sequence is absent, empty or all-NaN.
        Numeric input keeps the original behaviour (smoothing along the batch).
    """
    if values.empty:
        return pd.Series([], dtype="float64")
    if pd.api.types.is_numeric_dtype(values):
        return values.ewm(alpha=0.5, ignore_na=True).mean()

    def last_smoothed(window):
        # A null array or an empty window carries no information to smooth.
        if window is None or not hasattr(window, "__len__") or len(window) == 0:
            return None
        smoothed = (
            pd.Series(window, dtype="float64")
            .ewm(alpha=0.5, ignore_na=True)
            .mean()
            .iloc[-1]
        )
        # Return a missing value rather than NaN so the caller's coalesce
        # leaves the cell null instead of writing NaN into it.
        return None if pd.isna(smoothed) else float(smoothed)

    return pd.Series(
        [last_smoothed(window) for window in values],
        index=values.index,
        dtype="float64",
    )

def smooth_column_optimized(
    df, 
    col_name, 
    station_col="origin_STATION", 
    dt_col="sched_depart_date_time", 
    window_size=6
):
    """Fill the remaining nulls of `col_name` with a smoothed per-station estimate.

    The column is cast to double. For every row, the non-null values of the
    `window_size` preceding rows plus the row itself (same `station_col`,
    ordered by `dt_col` cast to long) are collected into a list, which is
    passed to `exponential_smoothing_pandas`. Existing values are kept; only
    nulls take the UDF result. The helper columns "non_null_values" and
    "smoothed" are dropped before returning.
    """
    trailing_rows = (
        Window.partitionBy(station_col)
        .orderBy(F.col(dt_col).cast("long"))
        .rowsBetween(-window_size, 0)
    )

    as_double = df.withColumn(col_name, F.col(col_name).cast(DoubleType()))
    # collect_list skips nulls, so the list only holds observed values.
    with_history = as_double.withColumn(
        "non_null_values", F.collect_list(col_name).over(trailing_rows)
    )
    with_estimate = with_history.withColumn(
        "smoothed", exponential_smoothing_pandas("non_null_values")
    )
    filled = with_estimate.withColumn(
        col_name, F.coalesce(F.col(col_name), F.col("smoothed"))
    )
    return filled.drop("non_null_values", "smoothed")


In [0]:
columns_to_fill = ['origin_HourlyWindSpeed','origin_HourlyDewPointTemperature','origin_HourlyDryBulbTemperature','origin_HourlyPressureChange','origin_HourlyRelativeHumidity','origin_HourlyWetBulbTemperature','origin_HourlyPrecipitation','origin_HourlyWindGustSpeed','origin_HourlyVisibility']

# Second pass: smooth whatever is still null, one column at a time.
# The loop variable is deliberately not named `col`: that would rebind the
# `col` function imported from pyspark.sql.functions to a plain string for the
# rest of the notebook session.
for smooth_col in columns_to_fill:
    df_interpolated = smooth_column_optimized(df_interpolated, smooth_col)



In [0]:
filled_cols = ['origin_HourlyVisibility','origin_HourlyWindSpeed','origin_HourlyDewPointTemperature','origin_HourlyDryBulbTemperature','origin_HourlyPressureChange','origin_HourlyRelativeHumidity','origin_HourlyWetBulbTemperature','origin_HourlyPrecipitation','origin_HourlyWindGustSpeed']

# Null/NaN counts per weather column after the smoothing pass, restricted to
# rows that have an origin latitude.
_remaining_null_exprs = [
    F.count(F.when(F.col(c).isNull() | F.isnan(c), c)).alias(c)
    for c in filled_cols
]
null_counts_ema = (
    df_interpolated
    .filter(F.col('origin_LATITUDE').isNotNull())
    .select(*_remaining_null_exprs)
)


In [0]:
# One row per stage, in order: original, after geohash fill, after smoothing.
display(null_counts_orig.unionByName(null_counts).unionByName(null_counts_ema))

In [0]:
# Persist the cleaned frame as a checkpoint; mode("overwrite") replaces any previous output at this path.
df_interpolated.write.mode("overwrite").parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned.parquet/")

# Lags

## Create features

In [0]:
# Reload the cleaned checkpoint; note this rebinds `df`, which earlier held the uncleaned frame.
df = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned.parquet/")

In [0]:
# Inspection subset: rows with a tail number and a scheduled departure after 2019-04-01, ordered per aircraft.
df_small = df.filter(F.col('TAIL_NUM').isNotNull()).filter(F.col('sched_depart_utc')>'2019-04-01').orderBy('TAIL_NUM','sched_depart_utc')

In [0]:
# Eyeball the prior-flight columns on the subset.
# NOTE(review): 'priorflight_deptime_calc' is selected twice — confirm whether a different column was intended.
display(df_small.select('TAIL_NUM','ORIGIN','DEST','sched_depart_utc','CANCELLED','two_hours_prior_depart_UTC','priorflight_origin','priorflight_dest', 'priorflight_deptime_calc','priorflight_deptime_calc','priorflight_cancelled_true'))

In [0]:
# Same inspection on the full frame (rows with a tail number), ordered per aircraft.
# NOTE(review): 'priorflight_deptime_calc' is selected twice — confirm whether a different column was intended.
display(df.filter(F.col('TAIL_NUM').isNotNull()).select('TAIL_NUM','ORIGIN','DEST','sched_depart_utc','CANCELLED','two_hours_prior_depart_UTC','priorflight_origin','priorflight_dest', 'priorflight_deptime_calc','priorflight_deptime_calc','priorflight_cancelled_true').orderBy('TAIL_NUM','sched_depart_utc'))

In [0]:
# Spot-check a single aircraft (tail number N102UW) in departure order.
# NOTE(review): 'priorflight_deptime_calc' is selected twice — confirm whether a different column was intended.
display(df.filter(F.col('TAIL_NUM')=='N102UW').select('ORIGIN','DEST','sched_depart_utc','CANCELLED','two_hours_prior_depart_UTC','priorflight_origin','priorflight_dest', 'priorflight_deptime_calc','priorflight_deptime_calc','priorflight_cancelled_true').orderBy('sched_depart_utc'))

In [0]:
# Rows where both this flight and the prior flight of the same aircraft are flagged as cancelled.
# NOTE(review): 'priorflight_deptime_calc' is selected twice — confirm whether a different column was intended.
display(df.filter((F.col('CANCELLED') == 1) & (F.col('priorflight_cancelled_true') == 1)).filter(F.col('TAIL_NUM').isNotNull()).select('TAIL_NUM','ORIGIN','DEST','sched_depart_utc','CANCELLED','two_hours_prior_depart_UTC','priorflight_origin','priorflight_dest', 'priorflight_deptime_calc','priorflight_deptime_calc','priorflight_cancelled_true').orderBy('TAIL_NUM','sched_depart_utc'))

In [0]:
# A plain str has no `.tzinfo` attribute (the original expression raised
# AttributeError); parse the ISO-8601 text first, then inspect its offset.
datetime.fromisoformat("2019-10-27T19:46:00.000+00:00").tzinfo

In [0]:
def to_utc(yyyymmdd, dep_hhmm, arr_hhmm, dep_tz, arr_tz, flight_dur):
    """
    Convert a flight's scheduled local departure/arrival times to UTC strings.

    yyyymmdd = FL_DATE, formatted 'YYYY-MM-DD'
    dep_hhmm = CRS_DEP_TIME (local clock time as HHMM)
    arr_hhmm = CRS_ARR_TIME (local clock time as HHMM)
    dep_tz = origin_timezone (timezone name understood by pytz)
    arr_tz = dest_timezone (timezone name understood by pytz)
    flight_dur = CRS_ELAPSED_TIME (currently unused; kept for the interface)

    An hour of 24 is read as 00:MM on the following day. The arrival is built
    on the same calendar date as the departure; if that puts it before the
    departure in UTC, one day is added.

    Returns a tuple (departure, arrival) of UTC timestamps formatted as
    '%Y-%m-%dT%H:%M:%S' strings.
    """
    dep_clock = int(dep_hhmm)
    arr_clock = int(arr_hhmm)
    year, month, day = (int(part) for part in yyyymmdd.split('-'))

    def local_clock_to_utc(clock, tz_name):
        """Turn an HHMM clock value on the flight date into an aware UTC datetime."""
        hour, minute = divmod(clock, 100)
        rolls_to_next_day = hour == 24
        naive = datetime(year, month, day, 0 if rolls_to_next_day else hour, minute)
        if rolls_to_next_day:
            naive += timedelta(days=1)
        # Attach the local zone first, then convert to UTC.
        return pytz.timezone(tz_name).localize(naive).astimezone(pytz.utc)

    dep_utc = local_clock_to_utc(dep_clock, dep_tz)
    arr_utc = local_clock_to_utc(arr_clock, arr_tz)

    # An arrival earlier than the departure means the flight lands the next day.
    if arr_utc < dep_utc:
        arr_utc += timedelta(days=1)

    stamp_format = "%Y-%m-%dT%H:%M:%S"
    return (dep_utc.strftime(stamp_format), arr_utc.strftime(stamp_format))

# Return type of the UDF: a struct of two non-nullable strings, matching the
# (departure, arrival) tuple returned by to_utc.
schema = StructType([
    StructField("dep_datetime", StringType(), False),
    StructField("arr_datetime", StringType(), False),
])

# Row-at-a-time Python UDF wrapping to_utc.
dt_udf = udf(to_utc, schema)

# Compute both timestamps once into a single struct column; the frame is cached
# before the struct is flattened below.
out = df.withColumn('processed', 
                                 dt_udf(F.col("FL_DATE"), 
                                        F.col("CRS_DEP_TIME"), 
                                        F.col("CRS_ARR_TIME"), 
                                        F.col("origin_timezone"), 
                                        F.col("dest_timezone"), 
                                        F.col("CRS_ELAPSED_TIME"))
                                 ).cache()

# Flatten: keep every original column and promote the two struct fields to
# top-level columns, dropping the struct itself.
cols = [c for c in out.columns if c != "processed"]
cols += ["processed.dep_datetime","processed.arr_datetime"]
out = out.select(cols)

display(out)

In [0]:

# Derive the scheduled arrival time in UTC (`sched_arr_utc`) from FL_DATE and
# CRS_ARR_TIME. All helper columns are dropped by the final select, which keeps
# only the columns of `df` plus `sched_arr_utc`.
df_dated = (df
.withColumn("FL_YMD", F.col("FL_DATE").substr(0,10))
.withColumn(
    "base_date", F.to_date(F.col("FL_YMD"), "yyyy-MM-dd")
).withColumn(
    "arr_time_adj",
    # 2400 is not a valid clock time; treat it as 00:00 (the day is bumped below).
    F.when(F.col("CRS_ARR_TIME") == 2400, F.lit("0000")).otherwise(F.col("CRS_ARR_TIME"))
).withColumn(
    "arr_time_str",
    F.lpad("arr_time_adj", 4, "0")
).withColumn(
    "arr_timestamp",
    F.concat(
        F.col("FL_YMD"),
        F.lit("T"),
        F.substring("arr_time_str", 1, 2),
        F.lit(":"),
        F.substring("arr_time_str", 3, 2),
        F.lit(":00")
    )
).withColumn(
    "local_datetime",
    F.to_timestamp("arr_timestamp", "yyyy-MM-dd'T'HH:mm:ss")
).withColumn(
    "arr_utc",
    F.when(
        F.col("CRS_ARR_TIME") == 2400,
        # date_add returns a DATE; that is harmless here because the local
        # time was rewritten to 00:00 above, so there is no time of day to lose.
        F.to_utc_timestamp(F.date_add("local_datetime", 1), F.col("dest_timezone"))
    ).otherwise(
        F.to_utc_timestamp("local_datetime", F.col("dest_timezone"))
    )
).withColumn(
    "sched_arr_utc",
    F.when(
        F.col("sched_depart_utc") > F.col("arr_utc"),
        # Arrival before departure means the flight lands on the next day.
        # Add a one-day interval instead of date_add: date_add returns a DATE,
        # which reset every overnight arrival to midnight.
        F.col("arr_utc") + F.expr("INTERVAL 1 DAY")
    ).otherwise(F.col("arr_utc"))
).withColumn(
    # NOTE(review): this column only survives the select below if `df` already
    # has an `actual_depart_utc` column, in which case it is overwritten here
    # with a value derived from the *scheduled* UTC departure — confirm intent.
    "actual_depart_utc",
    F.to_utc_timestamp(F.col("sched_depart_utc"), F.col("origin_timezone"))

).select(
    *[c for c in df.columns],
    "sched_arr_utc"
))

In [0]:
# Derive the actual arrival time in UTC (`actual_arr_utc`) from FL_DATE and
# ARR_TIME, mirroring the scheduled-arrival computation. The helper columns
# created here are left on the frame.
df_dated = (df_dated
.withColumn("FL_YMD", F.col("FL_DATE").substr(0,10))
.withColumn(
    "base_date", F.to_date(F.col("FL_YMD"), "yyyy-MM-dd")
).withColumn(
    "arr_time_adj",
    # 2400 is not a valid clock time; treat it as 00:00 (the day is bumped below).
    F.when(F.col("ARR_TIME") == 2400, F.lit("0000")).otherwise(F.col("ARR_TIME"))
).withColumn(
    "arr_time_str",
    F.lpad("arr_time_adj", 4, "0")
).withColumn(
    "arr_timestamp",
    F.concat(
        F.col("FL_YMD"),
        F.lit("T"),
        F.substring("arr_time_str", 1, 2),
        F.lit(":"),
        F.substring("arr_time_str", 3, 2),
        F.lit(":00")
    )
).withColumn(
    "local_datetime",
    F.to_timestamp("arr_timestamp", "yyyy-MM-dd'T'HH:mm:ss")
).withColumn(
    "arr_utc",
    F.when(
        F.col("ARR_TIME") == 2400,
        # date_add returns a DATE; that is harmless here because the local
        # time was rewritten to 00:00 above, so there is no time of day to lose.
        F.to_utc_timestamp(F.date_add("local_datetime", 1), F.col("dest_timezone"))
    ).otherwise(
        F.to_utc_timestamp("local_datetime", F.col("dest_timezone"))
    )
).withColumn(
    "actual_arr_utc",
    F.when(
        F.col("actual_depart_utc") > F.col("arr_utc"),
        # Arrival before departure means the flight landed on the next day.
        # Add a one-day interval instead of date_add: date_add returns a DATE,
        # which reset every overnight arrival to midnight.
        F.col("arr_utc") + F.expr("INTERVAL 1 DAY")
    ).otherwise(F.col("arr_utc"))
    )
)

In [0]:
# Remove the scratch columns used to assemble the arrival timestamps.
df_dated = df_dated.drop(
    'arr_time_adj',
    'arr_time_str',
    'arr_timestamp',
    'arr_utc',
    'local_datetime',
)

In [0]:
# Number of rows without an origin latitude.
df.filter(F.col('origin_LATITUDE').isNull()).count()

In [0]:


def add_lags_optimized(df):
    """Add prior-flight (same TAIL_NUM) lag features and derived columns.

    For each row, the previous row of the same aircraft (ordered by
    `sched_depart_utc`) supplies the `priorflight_*` columns. The derived
    columns compare the prior flight's times against
    `two_hours_prior_depart_UTC`.

    A prior flight counts as valid (`valid_prior`) when its destination equals
    this flight's origin, its scheduled departure is not earlier than
    `twentysix_hours_prior_depart_UTC` (`two_hours_prior_depart_UTC` minus
    24 hours), and it was not cancelled.

    Returns the input frame with the added columns; the result is cached.
    """
    # Define windows once
    aircraft_window = Window.partitionBy("TAIL_NUM").orderBy('sched_depart_utc')

    # route_window = Window.partitionBy("ORIGIN", "DEST").orderBy("sched_depart_utc").rowsBetween(-10, -1)

    # Column expressions are lazy: the priorflight_* / twentysix_* columns they
    # reference are only created on base_df below.
    WhenConditions = (
        (F.col("ORIGIN") == F.col("priorflight_dest")) & 
        (F.col("priorflight_sched_deptime") >= F.col("twentysix_hours_prior_depart_UTC"))
    )

    # Precompute all lagged columns in single pass
    # NOTE(review): "actual_arr_UTC" differs in case from "actual_arr_utc" as
    # created elsewhere in this notebook; this presumably relies on
    # case-insensitive column resolution — confirm.
    lagged_cols = [
        F.lag("CANCELLED").over(aircraft_window).alias("priorflight_cancelled_true"),
        F.lag("ORIGIN").over(aircraft_window).alias("priorflight_origin"),
        F.lag("DEST").over(aircraft_window).alias("priorflight_dest"),
        F.lag("sched_depart_utc").over(aircraft_window).alias("priorflight_sched_deptime"),
        F.lag("actual_depart_utc").over(aircraft_window).alias("priorflight_true_deptime"),
        F.lag("CRS_ELAPSED_TIME").over(aircraft_window).alias("priorflight_sched_elapsed"),
        F.lag("ACTUAL_ELAPSED_TIME").over(aircraft_window).alias("priorflight_true_elapsed"),
        F.lag("DEP_DELAY").over(aircraft_window).alias("priorflight_true_depdelay"),
        F.lag("sched_arr_utc").over(aircraft_window).alias("priorflight_sched_arrtime"),
        F.lag("actual_arr_UTC").over(aircraft_window).alias("priorflight_true_arrtime")
    ]

    valid_prior = WhenConditions & (F.col("priorflight_cancelled_true") == 0)


    # Base transformations
    base_df = (df
        .withColumn("twentysix_hours_prior_depart_UTC", 
                   (F.col("two_hours_prior_depart_UTC") - F.expr("INTERVAL 24 HOURS")).cast("timestamp"))
        .select("*", *lagged_cols)
    )    

    # Core calculations
    result_df = (base_df
        # Elapsed times become intervals (minutes * 1 minute); null when the
        # prior flight is not valid because the when() has no otherwise().
        .withColumn("priorflight_sched_elapsed",
            F.when(valid_prior,
                F.expr("INTERVAL 1 MINUTE") * F.col("priorflight_sched_elapsed")
            )
        )

        .withColumn("priorflight_true_elapsed",
                F.when(valid_prior,
                    F.expr("INTERVAL 1 MINUTE")* F.col("priorflight_true_elapsed")
                    ) 
                    
                )

        
        # This first definition is fully replaced by the second
        # "priorflight_depdelay_calc" withColumn further down.
        .withColumn("priorflight_depdelay_calc",
            F.when(valid_prior, F.col("priorflight_true_depdelay")).otherwise(F.lit(None))
        )

        # 1 when the valid prior flight actually departed no later than the
        # two-hours-prior cut-off.
        .withColumn("priorflight_isdeparted",
            F.when(
                (F.col("priorflight_true_deptime") <= F.col("two_hours_prior_depart_UTC")) &
                valid_prior, 1
            ).otherwise(0)
        )

        # Delay as known at the cut-off: the true delay if the prior flight had
        # departed by then; otherwise, if it was scheduled before the cut-off
        # but left after it, the minutes elapsed between its scheduled
        # departure and the cut-off; otherwise 0.
        .withColumn("priorflight_depdelay_calc",
            F.when(
                (F.col("priorflight_true_deptime") <= F.col("two_hours_prior_depart_UTC")) & valid_prior,
                F.col("priorflight_true_depdelay")
            ).when(
                (F.col("priorflight_sched_deptime") <= F.col("two_hours_prior_depart_UTC")) &
                (F.col("priorflight_true_deptime") > F.col("two_hours_prior_depart_UTC")) &
                valid_prior,
                (F.col("two_hours_prior_depart_UTC").cast('long') - 
                 F.col("priorflight_sched_deptime").cast('long')) / 60
            ).otherwise(F.lit(0.0)) #if not enough info, assume all is well - in line with other logic; edge cases handled later
        )
        .withColumn("priorflight_deptime_calc",  
            F.col("priorflight_sched_deptime") + 
            (F.expr("INTERVAL 1 MINUTE") * F.col("priorflight_depdelay_calc"))
        ) #non-valid prior depdelay calcs will get rewritten over in edge case handling
        
        # Delayed means a calculated delay of at least 15 minutes, or a
        # cancelled prior flight.
        .withColumn("priorflight_isdelayed_calc",
            F.when(
                (F.col("priorflight_depdelay_calc") >= 15) | 
                (F.col('priorflight_cancelled_true') == 1), 1
            ).otherwise(0)
        )

        .withColumn("priorflight_isarrived_calc",
            F.when(
                (F.col("priorflight_true_arrtime") <= F.col("two_hours_prior_depart_UTC")) &
                valid_prior, 1
            ).otherwise(0)
        )
        # Arrival estimate: the true arrival if it happened by the cut-off;
        # else true departure + scheduled duration if the flight had departed
        # by the cut-off; else calculated departure + scheduled duration.
        .withColumn("priorflight_arr_time_calc",
            F.when(
                F.col("priorflight_isarrived_calc") == 1,
                F.col("priorflight_true_arrtime")
            ).when(
                (F.col("priorflight_isarrived_calc") == 0) &
                (F.col("priorflight_true_deptime") <= F.col("two_hours_prior_depart_UTC")), 
                F.col("priorflight_true_deptime") + F.col("priorflight_sched_elapsed")
            ).otherwise(
                F.col("priorflight_deptime_calc") + F.col("priorflight_sched_elapsed")
            )
        )
        # Minutes between the estimated prior arrival and this flight's
        # scheduled departure; null when the prior flight is not valid.
        .withColumn("turnaround_time_calc",
            F.when(valid_prior,
                ((F.col("sched_depart_utc").cast("long") - 
                  F.col("priorflight_arr_time_calc").cast("long")) / 60).cast("double")
            ).otherwise(F.lit(None))
        )
        
    ).cache()

    return result_df

# # Execute pipeline
# result = add_lags_optimized(out)
# display(result)


In [0]:
team_BASE_DIR = f"dbfs:/student-groups/Group_4_1"

# NOTE(review): the "_cleaned_pr" checkpoint is not written anywhere in the visible
# part of this notebook — confirm which step produces it. This rebinds `df`.
df = spark.read.parquet(f"{team_BASE_DIR}/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned_pr.parquet")


In [0]:
# Per-aircraft ordering used by the lag features in the cells below.
aircraft_window = Window.partitionBy("TAIL_NUM").orderBy('sched_depart_utc')

    # route_window = Window.partitionBy("ORIGIN", "DEST").orderBy("sched_depart_utc").rowsBetween(-10, -1)

# Prior flight ended at this flight's origin and was scheduled no earlier than
# the twentysix_hours_prior_depart_UTC cut-off.
WhenConditions = (
        (F.col("ORIGIN") == F.col("priorflight_dest")) & 
        (F.col("priorflight_sched_deptime") >= F.col("twentysix_hours_prior_depart_UTC"))
    )


In [0]:
# Carrier of the same aircraft's previous flight.
df_= df.withColumn("priorflight_carrier", F.lag("OP_UNIQUE_CARRIER").over(aircraft_window))

In [0]:
# `mean_dep_delay` value carried by the same aircraft's previous flight row.
df_= df_.withColumn("priororigin_mean_dep_delay", F.lag("mean_dep_delay").over(aircraft_window))

In [0]:
# `origin_type` value carried by the same aircraft's previous flight row.
df_= df_.withColumn("priororigin_type", F.lag("origin_type").over(aircraft_window))

In [0]:
# Quick sanity check on the new lag column; limit(100) summarises only 100 rows, not the full frame.
df_.select('priororigin_mean_dep_delay').limit(100).summary().show()

In [0]:
# Quick sanity check on `route_risk`; limit(100) summarises only 100 rows, not the full frame.
df.select('route_risk').limit(100).summary().show()

In [0]:
# Persist the frame with the three added prior-flight columns as a new ("_v2") checkpoint.
df_.write.mode("overwrite").parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned_pr_v2.parquet/")

## Don't run: imputation

In [0]:

# Same validity rule as in add_lags_optimized: prior flight ended at this
# flight's origin, was scheduled within the 26-hour cut-off, and was not cancelled.
WhenConditions = (
        (F.col("ORIGIN") == F.col("priorflight_dest")) & 
        (F.col("priorflight_sched_deptime") >= F.col("twentysix_hours_prior_depart_UTC"))
    )
valid_prior = WhenConditions & (F.col("priorflight_cancelled_true") == 0)

# partition by route (ORIGIN->DEST)
# The partition also includes FL_DATE, so each window only sees flights of the
# same route on the same flight date.
hours = lambda i: i * 3600
window_spec = Window.partitionBy(F.col("ORIGIN"),F.col("DEST"), F.col("FL_DATE")) \
    .orderBy(F.col("sched_depart_utc").cast("long")
             ) \
        .rangeBetween(-hours(48),0)
# we will eventually get just -4 to -2 hours, but using 0 in the window allows us to
# grab the utc-2 for the 0 hour offset case
# NOTE(review): the range above is 48 hours, not the 4 hours mentioned here — confirm.

# NOTE(review): `df_lags` is not defined in the visible part of this notebook;
# presumably it is the output of add_lags_optimized — confirm before running.
df_routes = df_lags.repartition("ORIGIN", "DEST", "FL_DATE")


@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def mean_turnarounds_udf(turnarounds: pd.Series, 
                       act_dep_times: pd.Series, 
                       sched_dep_utc2: pd.Series) -> float:
    """Mean turnaround over the group's flights that departed before the cut-off.

    Args:
        turnarounds: turnaround times of the rows in the group/window.
        act_dep_times: actual departure times of those rows.
        sched_dep_utc2: cut-off times of those rows; the largest one is used
            as the cut-off for the whole group.

    Returns:
        The NaN-ignoring mean of the turnarounds whose actual departure is
        strictly before the cut-off (NaN when none qualify).
    """
    # `np.float` was only an alias of the builtin `float` and has been removed
    # from NumPy; the builtin gives the same float64 cast on every version.
    d = turnarounds[(act_dep_times < np.max(sched_dep_utc2))].astype(float)
    return np.nanmean(d)

# Apply the UDF over the window
df_lags_imputed = df_routes \
    .withColumn("mean_turnaround_calc", 
        F.when(~valid_prior,
        mean_turnarounds_udf(
                F.col("turnaround_time_calc"),
                F.col("actual_depart_utc"),
                F.col("two_hours_prior_depart_UTC")
            ).over(window_spec)).otherwise(F.col("turnaround_time_calc"))
        
        )

df_lags_imputed.cache()
display(df_lags_imputed)

In [0]:
# Rows where the imputation produced a value although the original turnaround was null.
display(df_lags_imputed.filter(F.col('turnaround_time_calc').isNull()).filter(F.col('mean_turnaround_calc').isNotNull()))

In [0]:
# NOTE(review): this overwrites the same "_cleaned" checkpoint path that the
# weather-cleaning section writes and the Lags section reads — confirm that
# replacing it with `df_lags` is intended.
df_lags.write.mode("overwrite").parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned.parquet/")

In [0]:
# Bare expression: shows the DataFrame's repr only; it does not trigger a display of rows.
df_lags_imputed

In [0]:
# Join-based alternative to the windowed UDF: pair every flight with the
# flights of the same route and flight date that had actually departed by its
# two-hours-prior cut-off, then average their turnaround times.
current = df_routes.alias("current")
prior = df_routes.alias("prior")

join_cond = (
    (F.col("current.ORIGIN") == F.col("prior.ORIGIN")) &
    (F.col("current.DEST") == F.col("prior.DEST")) &
    (F.col("current.FL_DATE") == F.col("prior.FL_DATE")) &
    (F.col("prior.actual_depart_utc") <= F.col("current.two_hours_prior_depart_UTC"))
)

# NOTE(review): confirm that groupBy accepts the "current.*" star expansion, and
# that the unqualified column names used by `valid_prior` resolve
# unambiguously after the aggregation.
result = (
    current.join(prior, join_cond, "left")
    .groupBy("current.*")
    .agg(F.avg("prior.turnaround_time_calc").alias("mean_turnaround_calc"))
    .withColumn(
        "turnaround_time_calc",
        F.when(
            ~valid_prior,  # Use your existing valid_prior condition
            F.col("mean_turnaround_calc")
        ).otherwise(F.col("turnaround_time_calc"))
    )
)

In [0]:
# A prior flight is usable when it ended at this flight's origin, was scheduled
# within the 26-hour cut-off, and was not cancelled.
valid_prior = (
    (F.col("ORIGIN") == F.col("priorflight_dest")) &
    (F.col("priorflight_sched_deptime") >= F.col("twentysix_hours_prior_depart_UTC")) &
    (F.col("priorflight_cancelled_true") == 0)
)

hours = lambda i: i * 3600

# Trailing 72 hours of flights on the same route, up to and including the
# current flight's scheduled departure.
window_spec = (
    Window.partitionBy("ORIGIN", "DEST")
    .orderBy(F.col("sched_depart_utc").cast("long"))
    .rangeBetween(-hours(72), 0)
)

df_optimized = (
    df_lags
    # Only turnarounds backed by a valid prior flight may enter the average.
    .withColumn(
        "valid_turnaround",
        F.when(valid_prior, F.col("turnaround_time_calc")).otherwise(F.lit(None)),
    )
    # Route-level mean over the trailing window.
    .withColumn("mean_turnaround", F.avg("valid_turnaround").over(window_spec))
    # Flights with an invalid prior, or no turnaround at all, take the route mean.
    .withColumn(
        "turnaround_time_calc",
        F.when(
            (~valid_prior | F.col('turnaround_time_calc').isNull()),
            F.col("mean_turnaround"),
        ).otherwise(F.col("turnaround_time_calc")),
    )
    .cache()
)

display(df_optimized)

In [0]:
# NOTE(review): this cell repeats the turnaround-imputation cell directly above it.
# A prior flight is valid when it ended at this flight's origin, was scheduled
# within the 26-hour cut-off, and was not cancelled.
valid_prior = (
    (F.col("ORIGIN") == F.col("priorflight_dest")) &
    (F.col("priorflight_sched_deptime") >= F.col("twentysix_hours_prior_depart_UTC")) &
    (F.col("priorflight_cancelled_true") == 0)
)

hours = lambda i: i * 3600

# window to include the prior 72 hours of flights on the same route
# (the range ends at 0, so the current row is part of the window)
window_spec = Window.partitionBy("ORIGIN", "DEST") \
                   .orderBy(F.col("sched_depart_utc").cast("long")) \
                   .rangeBetween(-hours(72),0)

# calculate conditional average
df_optimized = df_lags \
    .withColumn("valid_turnaround",
        F.when(
            valid_prior,
            F.col("turnaround_time_calc")
        ).otherwise(F.lit(None)) #only consider valid prior flights' turnaround times
    ) \
    .withColumn("mean_turnaround",
        F.avg("valid_turnaround").over(window_spec) #take mean over valid turnaround times
    ) \
    .withColumn("turnaround_time_calc",
        F.when(
            (~valid_prior | F.col('turnaround_time_calc').isNull()), 
               F.col("mean_turnaround")
               ) #replace flight turnarounds that have invalid priors to mean over past 72 hours of that route being flown
          .otherwise(F.col("turnaround_time_calc"))
    ).cache()

display(df_optimized)

In [0]:
        # # Edge case handling - use Erica's code for: 
        # #1) revise a bit for turnaround time calc - avg over route prior flights
        # #2) use original for priorflight depdelay calc - lag mean delay (@prev origin)
        # #3) then update priorflight_isdelayed_calc
        # .withColumn("turnaround_time_calc",
        #     F.when(
        #         (~valid_prior),
        #         F.last("turnaround_time_calc", ignorenulls=True).over(route_window)
        #     ).otherwise(F.col("turnaround_time_calc"))
        # )
        # .withColumn("priorflight_depdelay_calc",
        #     F.when(
        #         (~valid_prior),
        #         F.last("priorflight_depdelay_calc", ignorenulls=True).over(route_window)
        #     ).otherwise(F.col("priorflight_depdelay_calc"))
        # )

        # .withColumn("priorflight_isdelayed_calc",
        #     F.when(
        #         (F.col("priorflight_depdelay_calc") >= 15) | 
        #         (F.col('priorflight_cancelled_true') == 1), 1
        #     ).otherwise(0)
        # )

# Graph

Brainstorming:

PageRank
- It makes sense to compute PageRank within time windows: the airports that are most popular in the winter won't be the most popular in the summer. Windowing is also a reasonable assumption with respect to scheduling vs. leakage.
- Personalized PageRank: the teleportation factor should express a preference for the possible destinations with respect to the origin.
  - Convert to GraphX first.

In [0]:
from graphframes import GraphFrame


In [0]:
# Vertices: every airport code that appears as an origin or a destination.
v = (
    df_lags.select(F.col("ORIGIN").alias("id"))
    .union(df_lags.select(F.col("DEST").alias("id")))
    .distinct()
)

# Edges: one directed edge per flight row, origin -> destination.
e = df_lags.select(F.col("ORIGIN").alias("src"), F.col("DEST").alias("dst"))

g = GraphFrame(v, e)


In [0]:
# Define fold date ranges based on erica's cv.
# Each entry is filtered as date_min <= date < date_max, so date_max must be
# later than date_min.
folds = [
    {"fold": 'train_0', "date_min": "2014-12-31", "date_max": "2015-10-09"},
    {"fold": 'test_0', "date_min": "2015-10-09", "date_max": "2016-07-17"},
    {"fold": 'train_1', "date_min": "2015-08-14", "date_max": "2016-05-21"},
    {"fold": 'test_1', "date_min": "2016-05-21", "date_max": "2017-02-27"},
    {"fold": 'train_2', "date_min": "2016-03-27", "date_max": "2017-01-01"},
    {"fold": 'test_2', "date_min": "2017-01-01", "date_max": "2017-10-10"},
    {"fold": 'train_3', "date_min": "2016-11-08", "date_max": "2017-08-14"},
    {"fold": 'test_3', "date_min": "2017-08-14", "date_max": "2018-05-23"},
    {"fold": 'train_4', "date_min": "2017-06-22", "date_max": "2018-03-27"},
    # date_max was "2018-01-01", which lies before date_min and made this fold
    # empty. NOTE(review): "2019-01-01" is inferred from the length of the other
    # test folds — confirm against the CV definition.
    {"fold": 'test_4', "date_min": "2018-03-27", "date_max": "2019-01-01"}
    ]

In [0]:
def run_pagerank_on_folds(df, date_column='sched_depart_utc', folds=folds):
    """Compute airport PageRank separately for each date-bounded fold.

    Args:
        df: Spark DataFrame with ``ORIGIN`` and ``DEST`` columns plus ``date_column``.
        date_column: Name of the column compared against each fold's date bounds.
        folds: List of dicts with ``fold``, ``date_min`` and ``date_max`` keys.
            A row belongs to a fold when ``date_min <= date_column < date_max``.

    Returns:
        The union of every fold's PageRank vertices DataFrame, each tagged with a
        ``fold`` column holding that fold's label.

    Raises:
        IndexError: If ``folds`` is empty, because there is nothing to combine.
    """
    results = []

    for fold in folds:
        print(f"Processing fold {fold['fold']}")

        # Restrict the flights to this fold's half-open date range.
        fold_df = df.filter(
            (F.col(date_column) >= fold['date_min']) &
            (F.col(date_column) < fold['date_max'])
        )

        # Both endpoints must come from the fold itself. Taking destinations from
        # a notebook-level DataFrame would add airports from outside the fold's
        # date range to every fold's graph.
        vertices_df = fold_df.select(F.col("ORIGIN").alias("id")
                 ).union(
                     fold_df.select(F.col("DEST").alias("id"))
                     ).distinct()

        edges_df = fold_df.select(
                F.col("ORIGIN").alias("src"),
                F.col("DEST").alias("dst")
            )

        # Create GraphFrame for this fold
        g = GraphFrame(vertices_df, edges_df)

        # Run PageRank
        pagerank_results = g.pageRank(resetProbability=.15,
                                      maxIter= 10)

        # Add fold information
        pagerank_results_with_fold = pagerank_results.vertices.withColumn("fold", F.lit(fold['fold']))

        results.append(pagerank_results_with_fold)

    # Combine all results
    all_results = results[0]
    for i in range(1, len(results)):
        all_results = all_results.union(results[i])

    return all_results



In [0]:
pr_df= run_pagerank_on_folds(df_lags)

In [0]:
display(pr_df)

In [0]:
results = g.pageRank(resetProbability=0.15, maxIter = 10)


results.vertices.select("id", "pagerank").show()
results.edges.select("src", "dst", "weight").show()


In [0]:
display(results.edges.select('src','dst','weight').distinct().orderBy(F.col('weight').desc()))

In [0]:
display(results.vertices.select("id", "pagerank").orderBy(F.col('pagerank').desc()))


In [0]:
# paths = g.find("(a)-[e1]->(b); (b)-[e2]->(c); (c)-[e3]->(a)")  # needs to account for time (the immediate t+1 step), though

In [0]:
pr_df.write.mode("overwrite").parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/pagerank.parquet/")

## Combine

In [0]:
df = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned.parquet/")

In [0]:
pr = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/pagerank.parquet/")

In [0]:
# Cross-validation fold date ranges (based on erica's cv), listed as
# (label, date_min, date_max) and expanded into one dict per fold.
_fold_ranges = [
    ('train_0', "2014-12-31", "2015-10-09"),
    ('test_0', "2015-10-09", "2016-07-17"),
    ('train_1', "2015-08-14", "2016-05-21"),
    ('test_1', "2016-05-21", "2017-02-27"),
    ('train_2', "2016-03-27", "2017-01-01"),
    ('test_2', "2017-01-01", "2017-10-10"),
    ('train_3', "2016-11-08", "2017-08-14"),
    ('test_3', "2017-08-14", "2018-05-23"),
    ('train_4', "2017-06-22", "2018-03-27"),
    ('test_4', "2018-03-27", "2019-01-01"),
]

folds = [
    {"fold": fold_name, "date_min": range_start, "date_max": range_end}
    for fold_name, range_start, range_end in _fold_ranges
]

In [0]:
pr.filter(F.col('id')=='ATW').filter(F.col('fold')==0).show()

In [0]:
window_spec = Window.partitionBy('fold').orderBy('id')

pr_folds = pr.withColumn('row_num', F.row_number().over(window_spec))

In [0]:
# Number the PageRank rows within each (fold, id) group.
# NOTE(review): the window orders by 'id', which is constant inside a
# (fold, id) partition, so the ordering is a tie and this spec does not
# determine which row receives row_num 1 vs 2 — confirm this is intended.
window_spec = Window.partitionBy('fold', 'id').orderBy('id')

pr_folds = pr.withColumn('row_num', F.row_number().over(window_spec))

# Collapse every row number above 1 down to 2, leaving only the values 1 and 2.
pr_folds = pr_folds.withColumn('row_num', F.when(F.col('row_num') > 1, 2).otherwise(F.col('row_num')))

In [0]:
# Count the distinct airports (origin or destination) seen in folds[0].
# The bounds are the half-open range [date_min, date_max) so the count uses the
# same rows run_pagerank_on_folds selects when it builds the fold's graph.
df.filter(F.col('sched_depart_utc') >= folds[0]['date_min']) \
  .filter(F.col('sched_depart_utc') < folds[0]['date_max']) \
  .select(F.explode(F.array('ORIGIN', 'DEST')).alias('airport')) \
  .distinct() \
  .count()

In [0]:
# Count the distinct airports (origin or destination) seen in folds[1].
# The bounds are the half-open range [date_min, date_max) so the count uses the
# same rows run_pagerank_on_folds selects when it builds the fold's graph.
df.filter(F.col('sched_depart_utc') >= folds[1]['date_min']) \
  .filter(F.col('sched_depart_utc') < folds[1]['date_max']) \
  .select(F.explode(F.array('ORIGIN', 'DEST')).alias('airport')) \
  .distinct() \
  .count()

In [0]:
# Label each flight with every fold whose date range contains its departure.
# The fold ranges overlap, so a flight can appear once per matching fold.
folds_df = spark.createDataFrame(folds)

# Half-open [date_min, date_max) bounds, matching the filter that
# run_pagerank_on_folds applies when it builds each fold's graph; the previous
# (date_min, date_max] bounds assigned boundary flights to a different fold
# than the one whose PageRank they contributed to.
df_with_folds = df.join(
    folds_df,
    (df['sched_depart_utc'] >= folds_df['date_min']) & (df['sched_depart_utc'] < folds_df['date_max']),
    'inner'
).select(df['*'], folds_df['fold'])

display(df_with_folds)

In [0]:
df_with_folds.filter(F.col('fold').isNull()).count()

In [0]:
# Translate each (fold, row_num) pair into a label: row_num 1 -> 'train_<fold>',
# row_num 2 -> 'test_<fold>' for folds 0-4. Any other pair gets a null label.
fold_label_col = None
for fold_idx in range(5):
    for row_idx, split_name in ((1, 'train'), (2, 'test')):
        is_match = (F.col('fold') == fold_idx) & (F.col('row_num') == row_idx)
        split_label = f'{split_name}_{fold_idx}'
        if fold_label_col is None:
            fold_label_col = F.when(is_match, split_label)
        else:
            fold_label_col = fold_label_col.when(is_match, split_label)

pr_folds = pr_folds.withColumn('fold_label', fold_label_col)

display(pr_folds)

In [0]:
# Attach the origin airport's PageRank for the fold each flight belongs to.
pr_lookup = pr_folds.select('id', 'pagerank', 'fold_label')
same_fold = df_with_folds['fold'] == pr_folds['fold_label']
same_airport = df_with_folds['ORIGIN'] == pr_folds['id']

df_with_pr = df_with_folds.join(pr_lookup, same_fold & same_airport, 'left')

# Sanity check: number of joined rows whose fold is null.
df_with_pr.filter(F.col('fold').isNull()).count()

In [0]:
df_with_pr.groupBy('YEAR').count().show()

In [0]:
df_with_pr.filter(F.col('TAIL_NUM').isNotNull()).filter(F.col('pagerank').isNull()).count()

In [0]:
df_with_pr.write.mode('overwrite').parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned_pr.parquet/")

In [0]:
display(df_with_pr.filter(F.col('TAIL_NUM').isNotNull()).filter(F.col('pagerank').isNull()))

In [0]:
display(df_with_pr.filter(F.col('TAIL_NUM').isNotNull()).filter(F.col('pagerank').isNull()))

## last folds

In [0]:
df_with_pr = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned_pr.parquet/")

In [0]:
df = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned.parquet/")

In [0]:
# Rebuild the per-(fold, id) row numbers and the derived fold labels from `pr`.
window_spec = Window.partitionBy('fold', 'id').orderBy('id')

# Row numbers above 1 are all mapped to 2, so row_num ends up as either 1 or 2.
pr_folds = (
    pr
    .withColumn('row_num', F.row_number().over(window_spec))
    .withColumn('row_num', F.when(F.col('row_num') > 1, 2).otherwise(F.col('row_num')))
)

# row_num 1 -> 'train_<fold>', row_num 2 -> 'test_<fold>' for folds 0-4;
# any other (fold, row_num) pair gets a null label.
fold_label_col = None
for fold_idx in range(5):
    for row_idx, split_name in ((1, 'train'), (2, 'test')):
        is_match = (F.col('fold') == fold_idx) & (F.col('row_num') == row_idx)
        split_label = f'{split_name}_{fold_idx}'
        if fold_label_col is None:
            fold_label_col = F.when(is_match, split_label)
        else:
            fold_label_col = fold_label_col.when(is_match, split_label)

pr_folds = pr_folds.withColumn('fold_label', fold_label_col)

display(pr_folds)

In [0]:
pr_folds = pr_folds.filter(F.col('fold_label') != 'test_4')

In [0]:
display(pr_folds.filter(F.col('fold_label')=='train_4'))

In [0]:
pr = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/pagerank.parquet/")

# Folds

In [0]:
def run_pagerank_on_folds(df, date_column='sched_depart_utc', folds=folds):
    """Run PageRank on the flight graph of each fold and stack the results.

    Args:
        df: Spark DataFrame with ``ORIGIN``, ``DEST`` and ``date_column`` columns.
        date_column: Column compared against each fold's ``date_min``/``date_max``.
        folds: List of dicts with ``fold``, ``date_min`` and ``date_max`` keys;
            a row belongs to a fold when ``date_min <= date_column < date_max``.

    Returns:
        Union of each fold's PageRank vertices DataFrame with an added ``fold``
        column holding the fold label.
    """
    per_fold = []

    for spec in folds:
        fold_name = spec['fold']
        print(f"Processing fold {fold_name}")

        # Flights scheduled inside this fold's half-open date range.
        in_range = (
            (F.col(date_column) >= spec['date_min']) &
            (F.col(date_column) < spec['date_max'])
        )
        flights = df.filter(in_range)

        # Vertices are the airports touched by the fold's flights; edges are the flights.
        origins = flights.select(F.col("ORIGIN").alias("id"))
        destinations = flights.select(F.col("DEST").alias("id"))
        airports = origins.union(destinations).distinct()
        routes = flights.select(F.col("ORIGIN").alias("src"), F.col("DEST").alias("dst"))

        ranked = GraphFrame(airports, routes).pageRank(resetProbability=.15, maxIter=10)

        # Tag the vertex scores with the fold they were computed on.
        per_fold.append(ranked.vertices.withColumn("fold", F.lit(fold_name)))

    # Stack the per-fold results into a single DataFrame.
    combined = per_fold[0]
    for fold_ranks in per_fold[1:]:
        combined = combined.union(fold_ranks)

    return combined



In [0]:
pr_final = run_pagerank_on_folds(df, folds=blocks)

In [0]:
display(pr_folds)

In [0]:
all_pr_folds = pr_folds.drop('row_num').drop('fold').withColumnRenamed('fold_label','fold').unionByName(pr_final)

In [0]:
display(all_pr_folds)

In [0]:
display(pr_df_full)

In [0]:
# Label each flight with every block whose date range contains its departure.
folds_df = spark.createDataFrame(blocks)

# Half-open [date_min, date_max) bounds, matching the filter that
# run_pagerank_on_folds applies when it builds each fold's graph; the previous
# (date_min, date_max] bounds assigned boundary flights to a different fold
# than the one whose PageRank they contributed to.
df_with_folds = df.join(
    folds_df,
    (df['sched_depart_utc'] >= folds_df['date_min']) & (df['sched_depart_utc'] < folds_df['date_max']),
    'inner'
).select(df['*'], folds_df['fold'])

display(df_with_folds)

In [0]:
df_with_folds

In [0]:
df_with_folds_final = df_with_folds.filter(F.col('fold').contains('test'))

In [0]:
pr_final

In [0]:


# Left-join the per-fold origin-airport PageRank onto the filtered flights.
pr_final_lookup = pr_final.select('id', 'pagerank', 'fold')
pr_final_match = (
    (df_with_folds_final['fold'] == pr_final['fold']) &
    (df_with_folds_final['ORIGIN'] == pr_final['id'])
)

df_with_pr_final = df_with_folds_final.join(pr_final_lookup, pr_final_match, 'left')


In [0]:
# Build a wide table: df plus one PageRank column per fold label
# (train_0 ... test_4, and 'test'), each joined on the flight's ORIGIN airport.
# Every step filters one fold's scores, renames 'pagerank' to the fold label,
# and drops the helper 'fold' / 'id' columns.
# NOTE(review): join() is called without `how`, so these are inner joins — a
# flight whose ORIGIN is absent from any single fold's scores is dropped from
# the result. Confirm that is intended.
# NOTE(review): dftest3 and dftrain4 are not defined in the cells visible here —
# confirm they exist before this cell runs.
pr_df_full  = (df.join(all_pr_folds.filter(F.col('fold')=='train_0').withColumnRenamed('pagerank','train_0').drop('fold'), df['ORIGIN']==all_pr_folds['id']).drop('id')
        .join(all_pr_folds.filter(F.col('fold')=='test_0').withColumnRenamed('pagerank','test_0').drop('fold'), df['ORIGIN']==all_pr_folds['id']).drop('id')
        .join(all_pr_folds.filter(F.col('fold')=='train_1').withColumnRenamed('pagerank','train_1').drop('fold'), df['ORIGIN']==all_pr_folds['id']).drop('id')
        .join(all_pr_folds.filter(F.col('fold')=='test_1').withColumnRenamed('pagerank','test_1').drop('fold'), df['ORIGIN']==all_pr_folds['id']).drop('id')
        .join(all_pr_folds.filter(F.col('fold')=='train_2').withColumnRenamed('pagerank','train_2').drop('fold'), df['ORIGIN']==all_pr_folds['id']).drop('id')
        .join(all_pr_folds.filter(F.col('fold')=='test_2').withColumnRenamed('pagerank','test_2').drop('fold'), df['ORIGIN']==all_pr_folds['id']).drop('id')
        .join(all_pr_folds.filter(F.col('fold')=='train_3').withColumnRenamed('pagerank','train_3').drop('fold'), df['ORIGIN']==all_pr_folds['id']).drop('id')
        .join(dftest3.filter(F.col('fold') == 'test_3').withColumnRenamed('pagerank','test_3').drop('fold'), df['ORIGIN']==dftest3['id']).drop('id')
        .join(dftrain4.filter(F.col('fold')=='train_4').withColumnRenamed('pagerank','train_4').drop('fold'), df['ORIGIN']==dftrain4['id']).drop('id')
        .join(all_pr_folds.filter(F.col('fold')=='test_4').withColumnRenamed('pagerank','test_4').drop('fold'), df['ORIGIN']==all_pr_folds['id']).drop('id')
        .join(all_pr_folds.filter(F.col('fold')=='test').withColumnRenamed('pagerank','test').drop('fold'), df['ORIGIN']==all_pr_folds['id']).drop('id')

        
        )

In [0]:
df_with_pr_final.filter(F.col('pagerank').isNull()).count()

In [0]:
df_with_pr=df_with_pr.filter(F.col('fold')!='test_4')

In [0]:
df_with_pr_final.count()

In [0]:
df_with_pr = df_with_pr.dropDuplicates()

In [0]:
# Rename only the final column to 'fold_label' (positional rename via toDF);
# every other column keeps its current name.
original_cols = df_with_pr_final.columns
new_columns = original_cols[:-1] + ['fold_label']
df_with_pr_final = df_with_pr_final.toDF(*new_columns)


In [0]:
df_pr = df_with_pr.unionByName(df_with_pr_final)

In [0]:
df_pr.groupBy('YEAR').count().show()

In [0]:
# Date range covered by the combined dataset.
# The aggregates are reached through the existing `F` alias: importing `min` and
# `max` by name from pyspark.sql.functions would shadow Python's built-ins for
# every later cell. A single agg() also computes both bounds in one pass.
date_bounds = df_pr.agg(F.min("FL_DATE"), F.max("FL_DATE")).first()
min_date, max_date = date_bounds[0], date_bounds[1]

min_date, max_date

In [0]:
pr_df_full.write.mode("overwrite").parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned_pr.parquet/")

In [0]:
pr_df_full

# Graph Features Exploration

## Community Detection

Idea: community detection. Not every airport is concerned with every other airport, but each would be concerned about delays propagating within its own community. 

In [0]:
from graphframes import GraphFrame


In [0]:
df = spark.read.parquet("dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined__timefeat_seasfeat_cleaned.parquet/")

In [0]:
ydf=df.filter(F.col('YEAR')==2015)

In [0]:
# Vertices: every airport code that appears in ydf as an origin or a destination.
origin_ids = ydf.select(F.col("ORIGIN").alias("id"))
dest_ids = ydf.select(F.col("DEST").alias("id"))
v = origin_ids.union(dest_ids).distinct()

# Edges: one directed edge (origin -> destination) per flight row.
e = ydf.select(F.col("ORIGIN").alias("src"), F.col("DEST").alias("dst"))

g = GraphFrame(v, e)


### Label Propagation

TL;DR: as the Spark docs say, this mostly just groups everything under a single label.

In [0]:
result = g.labelPropagation(maxIter=15)
result.groupBy('label').count().show()

In [0]:
e.count()

In [0]:
result.filter(F.col('label')=='1219770712064').show()

### Personalized PR



In [0]:
from pyspark.sql.functions import udf
from pyspark.sql.types import MapType, IntegerType, DoubleType


Main idea: compute the importance of nodes relative to a specific source node or set of source nodes. We can then derive metrics for the relevant nodes. Here, each node represents an airport.



In [0]:
ydf.groupBy("ORIGIN").agg(F.countDistinct("DEST").alias("dest_count")).orderBy(F.desc("dest_count")).show()

In [0]:
from pyspark.sql.window import Window

# Vertices: distinct airport ids from both endpoints, sorted by id.
origin_ids = ydf.select(F.col("ORIGIN").alias("id"))
dest_ids = ydf.select(F.col("DEST").alias("id"))
v = origin_ids.union(dest_ids).distinct().orderBy("id")

# Edges: directed origin -> destination, one per flight row.
e = ydf.select(
    F.col("ORIGIN").alias("src"),
    F.col("DEST").alias("dst"),
)

# Assemble the graph from the two frames.
g = GraphFrame(v, e)


In [0]:

# Source ids sorted by id, each paired with a zero-based position.
id_order = Window.orderBy("id")
sources_df = (
    v.select("id")
    .distinct()
    .orderBy("id")
    .withColumn("index", F.row_number().over(id_order) - 1)
)

# Plain Python list of the ids, in the row order sources_df returns them.
sources_flat = sources_df.select("id").rdd.flatMap(lambda row: row).collect()


In [0]:

# Run PPR with aligned sources
pageranked = g.parallelPersonalizedPageRank(
    resetProbability=0.15, 
    sourceIds=sources_flat, 
    maxIter=10
)


In [0]:
display(pageranked.vertices.filter(F.col('id')=='JFK'))

In [0]:
# Example: check JFK's PPR scores.
# The JFK vertex is looked up by id. Indexing collect() by position relied on the
# collected rows arriving in sources_flat order (not guaranteed) and returned a
# one-field Row wrapping the vector, so zipping it produced a single pair.
jfk_index = sources_flat.index("JFK")
jfk_row = pageranked.vertices.filter(F.col("id") == "JFK").select("pageranks").first()
jfk_scores = [float(score) for score in jfk_row["pageranks"].toArray()]

# Pair each score with the source airport at the same vector position
# (sources_flat was passed as sourceIds, so positions follow its order).
sorted_scores = sorted(
    zip(sources_flat, jfk_scores),
    key=lambda pair: pair[1],
    reverse=True
)

# Exclude the self-score by name instead of assuming JFK always sorts first.
top_external = [pair for pair in sorted_scores if pair[0] != "JFK"][:10]
print(top_external)


In [0]:
sources_flat[51]

In [0]:


# Broadcast the ordered source ids once so the UDF can translate vector
# positions back to airport codes on the executors.
broadcast_sources = sc.broadcast(sources_flat)

# ArrayType is reached through the already-imported `types` module, since no
# by-name import of it is visible in this notebook.
result_schema = types.ArrayType(
    StructType([
        StructField("origin", StringType()),
        StructField("score", DoubleType())
    ])
)


def vector_to_dict(vector):
    """Return the ten highest-scoring entries of a pageranks vector.

    Args:
        vector: The ``pageranks`` value of one vertex; position ``i`` is paired
            with the ``i``-th broadcast source id.

    Returns:
        A list of up to ten ``(source id, score)`` tuples, highest score first.
    """
    # Read the ids from the broadcast value rather than the driver-side list,
    # so the full list is not serialized into the UDF closure as well.
    sources = broadcast_sources.value

    # Sort the vector entries by score and keep the top 10.
    sorted_entries = sorted(
        ((i, float(v)) for i, v in enumerate(vector)),
        key=lambda entry: entry[1],
        reverse=True
    )[:10]

    # Map vector positions to the actual source ids.
    return [(sources[i], score) for i, score in sorted_entries]


# Define UDF
vector_to_dict_udf = udf(vector_to_dict, result_schema)

results = pageranked.vertices.withColumn("pagerank_dict", vector_to_dict_udf("pageranks"))


In [0]:
[u for u in enumerate(sources_flat) if u[1]=='HYA']

In [0]:
results.printSchema()

In [0]:
def vector_to_map(vector):
    """Map each stored entry of a pageranks vector to its source id.

    Reads the vector's ``indices`` and ``values`` attributes and resolves every
    index through the broadcast source list.

    Returns:
        A dict of ``{source id: score}`` with scores converted to ``float``.
    """
    id_by_position = broadcast_sources.value
    positions = vector.indices.tolist()
    scores = vector.values.tolist()

    return dict(
        (id_by_position[position], float(score))
        for position, score in zip(positions, scores)
    )

# Wrap as a UDF returning map<string, double> and add it as a new column.
vector_to_map_udf = udf(vector_to_map, MapType(StringType(), DoubleType()))
results = results.withColumn("pagerank_map", vector_to_map_udf(col("pageranks")))


pageranks (VectorType): the pageranks of this vertex from all input source vertices


In [0]:
r

In [0]:
display(pageranked.edges.limit(10))

In [0]:
display(results)

In [0]:
display(results.withColumn('pagerank', F.col('pageranks').cast('string')).limit(5))

In [0]:

# Re-broadcast the source ids and redefine the top-10 UDF, then recompute
# the "pagerank_dict" column on `results`.
# NOTE(review): `sources_list` is not defined in the cells visible here (the
# earlier cells build `sources_flat`) — confirm where it comes from and that its
# order matches the sourceIds passed to parallelPersonalizedPageRank.
broadcast_sources = sc.broadcast(sources_list)

# Output type of the UDF: an array of (origin, score) structs.
result_schema = ArrayType(
    StructType([
        StructField("origin", StringType()),
        StructField("score", DoubleType())
    ])
)
def vector_to_dict(vector):
    """Return the ten highest-scoring entries of ``vector`` as (source id, score) tuples.

    Position ``i`` of the vector is paired with the ``i``-th broadcast source id;
    the result is ordered from highest to lowest score.
    """
    # Retrieve the broadcasted list
    sources = broadcast_sources.value
    
    # Sort the vector entries and take top 10
    sorted_entries = sorted(
        [(i, float(v)) for i, v in enumerate(vector)], 
        key=lambda x: x[1], 
        reverse=True
    )[:10]
    
    # Map indices to actual source IDs
    return [(sources[i], float(v)) for i, v in sorted_entries]

# Define UDF
vector_to_dict_udf = udf(vector_to_dict, result_schema)

# Apply UDF
results = results.withColumn("pagerank_dict", vector_to_dict_udf("pageranks"))
display(results.limit(5)) #sanity check

In [0]:
display(results.filter(F.col('id')=='JFK'))

In [0]:
ydf.filter(F.col("ORIGIN")=="JFK").filter(F.col("DEST")=="HYA").count()

In [0]:
df.filter(F.col("ORIGIN")=="JFK").groupBy("DEST").count().orderBy(F.desc(F.col('count'))).show()

In [0]:
df.filter(F.col("ORIGIN")=="HYA").groupBy("DEST").count().orderBy(F.desc(F.col('count'))).show()

In [0]:
df.filter(F.col("ORIGIN")=="HYA").filter(F.col("DEST")=="JFK").count()

In [0]:
df.filter(F.col("ORIGIN")=="JFK").filter(F.col("DEST")=="HYA").count()

In [0]:
[i for i in enumerate(sources_list) if i[1]=='HYA']

In [0]:
display(results.filter(F.col('id')=='BGM').select(F.map_keys('pagerank_dict').cast("string")))

In [0]:
display(results.limit(5))

In [0]:
results.select('id', 'pagerank_dict')

In [0]:
sources_list[200]

In [0]:
# `results` was rebuilt from pageranked.vertices with withColumn, so it is a
# DataFrame with no `.vertices` attribute; apply the UDF to the PPR output's
# vertices directly.
display(pageranked.vertices.withColumn("pagerank_dict", vector_to_dict_udf("pageranks")).limit(5))