# Imports

In [0]:
!pip install python-geohash

In [0]:
import geohash
from geohash import bbox
import pyspark.sql.functions as F
from pyspark.sql.functions import col,isnan, when, count, concat_ws, countDistinct, collect_set, rank, window, avg, hour, udf, isnan, pandas_udf, to_timestamp, lit, udf
import matplotlib.pyplot as plt
import pandas as pd
import re
import pytz
from datetime import datetime, timedelta, time
import numpy as np
from pyspark.sql import types
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, StructType, DoubleType, LongType

from pyspark.sql import Window
import seaborn as sns

In [0]:
# Weather-station metadata from the course dataset (all parquet parts under stations_data/)
stations = spark.read.parquet("dbfs:/mnt/mids-w261/datasets_final_project_2022/stations_data/*")


In [0]:
# Team workspace on DBFS; Spark checkpoints are written under <root>/interim
team_BASE_DIR = "dbfs:/student-groups/Group_4_1"
spark.sparkContext.setCheckpointDir(team_BASE_DIR + "/interim")

# Time slice of the flights+weather join to load: one of "", "3m", "6m", "1y"
period = "1y"
joined_path = f"{team_BASE_DIR}/interim/join_checkpoints/joined_flights_weather_{period}_v1.parquet"
df_ = spark.read.parquet(joined_path)

In [0]:
#for modeling checkpointing
# For modeling checkpointing: same checkpoint directory, re-declared so this cell runs on its own
team_BASE_DIR = "dbfs:/student-groups/Group_4_1"
spark.sparkContext.setCheckpointDir("{}/interim".format(team_BASE_DIR))


# Engineering

In [0]:
def to_utc(yyyymmdd, dep_hhmm, arr_hhmm, dep_tz, arr_tz, flight_dur):
    """
    Create UTC timestamps from flights table columns.

    yyyymmdd   = FL_DATE, a 'YYYY-MM-DD' string
    dep_hhmm   = CRS_DEP_TIME, local scheduled departure as hhmm (2400 = midnight ending the day)
    arr_hhmm   = CRS_ARR_TIME, local scheduled arrival as hhmm
    dep_tz     = origin_timezone, a time zone name accepted by pytz.timezone
    arr_tz     = dest_timezone, a time zone name accepted by pytz.timezone
    flight_dur = CRS_ELAPSED_TIME; unused (kept for the disabled arrival sanity check)

    Returns (dep_utc, arr_utc) as "%Y-%m-%dT%H:%M:%S" strings. If the UTC arrival comes out
    earlier than the UTC departure, the arrival is moved to the next day (overnight flight).

    Returns None when the date, either time, or either time zone is missing. int(None) and
    pytz.timezone(None) both raise, which would fail the whole Spark job for one null row;
    returning None yields a null struct for that row instead.
    """
    if any(value is None for value in (yyyymmdd, dep_hhmm, arr_hhmm, dep_tz, arr_tz)):
        return None

    yyyy, MM, dd = (int(part) for part in yyyymmdd.split('-'))

    def local_hhmm_to_utc(hhmm, tz_name):
        # Local wall-clock hhmm on FL_DATE -> timezone-aware UTC datetime
        hh, mm = divmod(int(hhmm), 100)
        next_day = hh == 24  # 2400 is represented as 00:00 on the following day
        if next_day:
            hh = 0
        local = datetime(yyyy, MM, dd, hh, mm)
        if next_day:
            local += timedelta(days=1)
        return pytz.timezone(tz_name).localize(local).astimezone(pytz.utc)

    dep_utc = local_hhmm_to_utc(dep_hhmm, dep_tz)
    arr_utc = local_hhmm_to_utc(arr_hhmm, arr_tz)
    if dep_utc > arr_utc:
        arr_utc += timedelta(days=1)

    dt_format = "%Y-%m-%dT%H:%M:%S"
    return (dep_utc.strftime(dt_format), arr_utc.strftime(dt_format))

# UDF return type: (departure, arrival) UTC timestamps as strings
schema = StructType([
    StructField("dep_datetime", StringType(), False),
    StructField("arr_datetime", StringType(), False),
])

dt_udf = udf(to_utc, schema)

# `df` is only assigned much later in the notebook, so in a fresh kernel this cell must read
# from the flights+weather join loaded above (`df_`).
out = df_.withColumn(
    'processed',
    dt_udf(F.col("FL_DATE"), F.col("CRS_DEP_TIME"), F.col("CRS_ARR_TIME"),
           F.col("origin_timezone"), F.col("dest_timezone"), F.col("CRS_ELAPSED_TIME"))
).cache()

# Flatten the struct: all original columns followed by the two UTC timestamps
cols = [c for c in out.columns if c != "processed"]
cols += ["processed.dep_datetime","processed.arr_datetime"]
out = out.select(cols)

display(out)

# Weather

## Nulls

### Step 1: Parse METAR reports

In [0]:


# METAR surface-wind group, e.g. "22010KT", "22015G25KT", "VRB03KT":
#   group 1 = direction (3 digits, or VRB for variable), group 2 = sustained speed,
#   group 3 = optional gust speed. Variable-direction winds were not matched before.
METAR_WIND_PATTERN = r'\b(\d{3}|VRB)(\d{2,3})(?:G(\d{2,3}))?KT\b'
# Peak-wind remark "PK WND dddff(f)/(hh)mm": group 2 = peak speed. The time may be given as
# minutes only ("/15"), which a 4-digit-only time pattern misses.
METAR_PEAK_WIND_PATTERN = r'PK WND (\d{3})(\d{2,3})/(\d{2}(?:\d{2})?)'

# Only nulls are filled. regexp_extract returns "" when nothing matches, and "" cast to int
# is null, so rows without a usable METAR group stay null for the later imputation steps.
df_interpolate = (
    df_
    .withColumn(
        "origin_HourlyWindSpeed",
        F.when(
            F.col("origin_HourlyWindSpeed").isNull(),
            F.regexp_extract(F.col("origin_REM"), METAR_WIND_PATTERN, 2).cast("int")
        ).otherwise(F.col("origin_HourlyWindSpeed"))
    )
    .withColumn(
        "origin_HourlyWindGustSpeed",
        F.when(
            F.col("origin_HourlyWindGustSpeed").isNull(),
            # greatest() skips nulls, so either the G group or the PK WND remark alone suffices
            F.greatest(
                F.regexp_extract(F.col("origin_REM"), METAR_WIND_PATTERN, 3).cast("int"),
                F.regexp_extract(F.col("origin_REM"), METAR_PEAK_WIND_PATTERN, 2).cast("int")
            )
        ).otherwise(F.col("origin_HourlyWindGustSpeed"))
    )
)


# Hourly precipitation group in the remarks is "Prrrr": exactly four digits in hundredths of an
# inch (e.g. " P0012" = 0.12 in). Requiring four digits keeps other tokens that begin with "P"
# plus a digit (e.g. a "P6SM" visibility) from being read as precipitation.
HOURLY_PRECIP_PATTERN = r"\sP(\d{4})\b"

df_interpolate = (df_interpolate
    .withColumn(
        'origin_HourlyPrecipitation',
        F.when(
            (F.col("origin_HourlyPrecipitation").isNull()) | (F.col("origin_HourlyPrecipitation") == '*'),
            F.regexp_extract(F.col("origin_REM"), HOURLY_PRECIP_PATTERN, 1).cast("int") * 0.01
        ).otherwise(F.col("origin_HourlyPrecipitation"))
    )
    # 'T' (presumably "trace") is recorded as the smallest measurable amount
    .withColumn('origin_HourlyPrecipitation', F.regexp_replace('origin_HourlyPrecipitation', 'T', '0.01'))
    # Keep only the leading number (drops any suffix letters); no number -> "" -> null below
    .withColumn(
        'origin_HourlyPrecipitation',
        F.regexp_extract('origin_HourlyPrecipitation', r"[0-9]+(\.[0-9]+)?", 0)
    )
    .withColumn('origin_HourlyPrecipitation', F.col('origin_HourlyPrecipitation').cast(DoubleType())))

In [0]:
# Inspect the frame after the METAR-based wind and precipitation extraction
display(df_interpolate)

In [0]:
# Cast wind speed to int and replace the 2237 outlier with 0 *before* imputation.
# Otherwise the geohash/EMA fills below copy or average it into neighbouring rows, where the
# exact-match cleanup done later can no longer find it. (The EDA spot checks found that many
# 2237 readings come from reports that say 0KT.)
WIND_SPEED_OUTLIER = 2237
wind_speed_int = F.col("origin_HourlyWindSpeed").cast("int")
df_interpolate = df_interpolate.withColumn(
    "origin_HourlyWindSpeed",
    F.when(wind_speed_int == WIND_SPEED_OUTLIER, 0).otherwise(wind_speed_int),
)

In [0]:
# checkpoint() returns a NEW DataFrame backed by the checkpoint files; it must be reassigned,
# otherwise the data is written but df_interpolate keeps its full, untruncated lineage.
df_interpolate = df_interpolate.checkpoint()

### Step 2: Geohashing

In some cases, the REM reports are null or do not contain the information needed to fill in null values. For those rows, use the observation from a nearby station (one in the same geohash cell) that is closest in time to the current record.

TODO: 
- Make sure we only pull observations taken at or before the record's time.
- When several stations in the same geohash report at the same time (a time tie), average their reports.

#### Encoding

In [0]:
# Origin coordinates and station name, used for the geohash encoding below
display(df_.select(['origin_LATITUDE', 'origin_LONGITUDE', 'origin_NAME']))

Hash airports to grid cells (geohashes) so we can pull values from nearby airports at around the same time.

In [0]:
def encode_geohash(precision: int):
    """Build a pandas UDF that encodes (latitude, longitude) pairs as geohash strings.

    Args:
        precision: number of geohash characters to emit (fewer characters = larger cells).

    Returns:
        A string-typed pandas UDF taking a latitude column and a longitude column. Rows with
        a missing coordinate, or whose encoding fails, get null.
    """
    @pandas_udf("string")
    def encode(latitudes: pd.Series, longitudes: pd.Series) -> pd.Series:
        def safe_encode(lat, lon):
            # Missing coordinates have no cell. Reject them explicitly instead of relying on
            # the geohash library to raise for NaN input.
            if pd.isna(lat) or pd.isna(lon):
                return None
            try:
                return geohash.encode(lat, lon, precision)
            except Exception:
                # Best-effort: coordinates the library rejects (presumably out of range) -> null
                return None
        return latitudes.combine(longitudes, safe_encode)
    return encode

# Tag every row with the 2-character geohash cell of its origin airport
geohash_udf = encode_geohash(precision=2)
origin_lat_lon = (F.col('origin_LATITUDE'), F.col('origin_LONGITUDE'))
df_interpolate = df_interpolate.withColumn('geohash', geohash_udf(*origin_lat_lon))
display(df_interpolate)

In [0]:
# Rows per geohash cell
df_interpolate.groupBy('geohash').count().show()

In [0]:
# Spot check: which geohash cells do Illinois airports fall into?
illinois_rows = df_interpolate.where(F.col('origin_REGION').contains('US-IL'))
display(illinois_rows.select(['origin_airport_name', 'geohash']).distinct())

#### Imputation

In [0]:
def coalesce_within_geohash(
    df, 
    target_col, 
    geohash_col="geohash", 
    dt_col="two_hours_prior_depart_UTC", 
    window_hours=2 
):
    """Fill nulls in `target_col` from other rows in the same geohash cell.

    For each row, the fill value is the last non-null `target_col` among rows in the same
    `geohash_col` cell whose `dt_col` lies in [t - window_hours, t]. The frame is a RANGE
    frame ending at the current row, so it also includes peer rows with exactly the same
    timestamp. Non-null values are kept as is.

    Rows with a null geohash are left untouched. Spark would otherwise put all of them into a
    single window partition and fill them from unrelated rows whose location is also unknown.

    Args:
        df: input DataFrame.
        target_col: column to fill.
        geohash_col: spatial grouping column.
        dt_col: time column; cast to long (epoch seconds for a timestamp) for the range window.
        window_hours: how far back, in hours, to look for a value.

    Returns:
        DataFrame with `target_col` filled where possible.
    """
    # rangeBetween works on the ordering value, here in seconds
    window_seconds = window_hours * 3600

    window_spec = (
        Window.partitionBy(geohash_col)
              .orderBy(F.col(dt_col).cast("long"))
              .rangeBetween(-window_seconds, 0)  # time-based window ending at the current row
    )

    neighbour_value = F.when(
        F.col(geohash_col).isNotNull(),
        F.last(target_col, ignorenulls=True).over(window_spec),
    )

    return df.withColumn(
        target_col,
        F.coalesce(F.col(target_col), neighbour_value)  # keep original value if not null
    )


In [0]:


@pandas_udf(DoubleType())
def exponential_smoothing_pandas(values: pd.Series) -> pd.Series:
    """Vectorized UDF: exponentially smoothed (alpha=0.5) value of each row's history array.

    Each element of `values` is one row's array: the output of `F.collect_list` over a window
    of previous rows. Because it is an array column, the Series has object dtype. Checking
    the Series itself with `is_numeric_dtype` is therefore always False, and every row would
    be "smoothed" to 0.0.

    For each row, this computes the EWM over the array elements (assumed oldest first, i.e.
    window order — confirm) and returns its last point: the smoothed estimate after the most
    recent observation. Rows with an empty or all-NaN history get NaN.
    NOTE(review): PySpark's Arrow conversion is expected to pass NaN back as null, so
    `F.coalesce` leaves those rows null — verify on the cluster's Spark version.
    """
    def _last_ewm(history):
        if history is None or len(history) == 0:
            return np.nan
        smoothed = pd.Series(history, dtype="float64").ewm(alpha=0.5, ignore_na=True).mean()
        return smoothed.iloc[-1]

    return values.map(_last_ewm).astype("float64")

def smooth_column_optimized(
    df, 
    col_name, 
    airport_col="ORIGIN", 
    dt_col="two_hours_prior_depart_UTC", 
    window_size=8
):
    """Fill nulls in `col_name` with an exponentially smoothed value of recent history.

    The column is cast to double. For each row, the non-null values of the previous
    `window_size` rows at the same `airport_col` (ordered by `dt_col`, current row excluded)
    are collected and passed to `exponential_smoothing_pandas`. The result only replaces
    nulls. The helper columns "non_null_values" and "smoothed" are dropped before returning.
    """
    # Previous window_size rows at the same airport, excluding the current row
    history_window = (
        Window.partitionBy(airport_col)
              .orderBy(F.col(dt_col).cast("long"))
              .rowsBetween(-window_size, -1)
    )

    numeric = df.withColumn(col_name, F.col(col_name).cast(DoubleType()))
    # collect_list skips nulls, so the arrays hold only observed values
    with_history = numeric.withColumn("non_null_values", F.collect_list(col_name).over(history_window))
    with_smoothed = with_history.withColumn("smoothed", exponential_smoothing_pandas("non_null_values"))
    filled = with_smoothed.withColumn(col_name, F.coalesce(F.col(col_name), F.col("smoothed")))
    return filled.drop("non_null_values", "smoothed")


In [0]:
def impute_nulls(
    df, 
    target_col, 
    geohash_col="geohash", 
    airport_col="ORIGIN", 
    dt_col="two_hours_prior_depart_UTC", # everything has to be at or before this threshold
    window_hours=2,
    window_size=8,
):
    """Two-step null imputation:
    1. Fill with the latest value from the same geohash within `window_hours` (spatial).
    2. Fill remaining nulls with an EMA over the previous `window_size` rows at the same
       airport (temporal).

    `window_hours` and `window_size` are passed through to the two helpers; their defaults
    match the helpers' own defaults, so existing calls behave as before.
    """
    # Step 1: Geohash-based imputation
    df = coalesce_within_geohash(
        df, 
        target_col=target_col,
        geohash_col=geohash_col,
        dt_col=dt_col,
        window_hours=window_hours,
    )

    # Step 2: EMA-based imputation for remaining nulls
    df = smooth_column_optimized(
        df,
        col_name=target_col,
        airport_col=airport_col,
        dt_col=dt_col,
        window_size=window_size,
    )

    return df


In [0]:
# Weather columns to impute; all are cast to float so the imputation helpers see numbers
columns_to_fill = ['origin_HourlyVisibility','origin_HourlyWindSpeed','origin_HourlyDewPointTemperature','origin_HourlyDryBulbTemperature','origin_HourlyPressureChange','origin_HourlyRelativeHumidity','origin_HourlyWetBulbTemperature','origin_HourlyPrecipitation','origin_HourlyWindGustSpeed']

df_interpolated = df_interpolate.withColumns(
    {name: df_interpolate[name].cast("float") for name in columns_to_fill}
)




In [0]:

# Impute each weather column in turn. The loop variable must not be called `col`: that would
# shadow pyspark.sql.functions.col (imported at the top) for every later cell.
for target_col in columns_to_fill:
    df_interpolated = impute_nulls(df_interpolated, target_col)



In [0]:

def _null_or_nan_counts(frame, stage):
    """One-row DataFrame: null/NaN count per column in `columns_to_fill`, among rows with a
    known origin latitude, labelled with `stage` so unioned rows can be told apart."""
    counts = frame.filter(F.col('origin_LATITUDE').isNotNull()).agg(
        *[F.count(F.when(F.col(c).isNull() | F.isnan(c), c)).alias(c) for c in columns_to_fill]
    )
    return counts.select(F.lit(stage).alias("stage"), *columns_to_fill)

null_counts_orig = _null_or_nan_counts(df_interpolate, "before_imputation")
null_counts = _null_or_nan_counts(df_interpolated, "after_imputation")

# Null counts where we have a non-null location (a null location cannot be filled by geohash)
display(null_counts_orig.unionByName(null_counts))

In [0]:


# Persist the imputed weather join (overwrites any previous run)
output_path = "dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined_1y_weather_cleaned_v2.parquet"
writer = df_interpolated.write.mode("overwrite")
writer.parquet(output_path)




In [0]:
# Row count of the imputed weather frame (sanity check)
df_interpolated.count()

# Flights

## Lag features engineering

### Encoding

Create UTC departure and arrival timestamps for the lag features.

In [0]:
def to_utc(yyyymmdd, dep_hhmm, arr_hhmm, dep_tz, arr_tz, flight_dur):
    """
    Create UTC timestamps from flights table columns.

    yyyymmdd   = FL_DATE, a 'YYYY-MM-DD' string
    dep_hhmm   = CRS_DEP_TIME, local scheduled departure as hhmm (2400 = midnight ending the day)
    arr_hhmm   = CRS_ARR_TIME, local scheduled arrival as hhmm
    dep_tz     = origin_timezone, a time zone name accepted by pytz.timezone
    arr_tz     = dest_timezone, a time zone name accepted by pytz.timezone
    flight_dur = CRS_ELAPSED_TIME; unused (kept for the disabled arrival sanity check)

    Returns (dep_utc, arr_utc) as "%Y-%m-%dT%H:%M:%S" strings. If the UTC arrival comes out
    earlier than the UTC departure, the arrival is moved to the next day (overnight flight).

    Returns None when the date, either time, or either time zone is missing. int(None) and
    pytz.timezone(None) both raise, which would fail the whole Spark job for one null row;
    returning None yields a null struct for that row instead.
    """
    if any(value is None for value in (yyyymmdd, dep_hhmm, arr_hhmm, dep_tz, arr_tz)):
        return None

    yyyy, MM, dd = (int(part) for part in yyyymmdd.split('-'))

    def local_hhmm_to_utc(hhmm, tz_name):
        # Local wall-clock hhmm on FL_DATE -> timezone-aware UTC datetime
        hh, mm = divmod(int(hhmm), 100)
        next_day = hh == 24  # 2400 is represented as 00:00 on the following day
        if next_day:
            hh = 0
        local = datetime(yyyy, MM, dd, hh, mm)
        if next_day:
            local += timedelta(days=1)
        return pytz.timezone(tz_name).localize(local).astimezone(pytz.utc)

    dep_utc = local_hhmm_to_utc(dep_hhmm, dep_tz)
    arr_utc = local_hhmm_to_utc(arr_hhmm, arr_tz)
    if dep_utc > arr_utc:
        arr_utc += timedelta(days=1)

    dt_format = "%Y-%m-%dT%H:%M:%S"
    return (dep_utc.strftime(dt_format), arr_utc.strftime(dt_format))

# UDF return type: (departure, arrival) UTC timestamps as strings
schema = StructType([
    StructField("dep_datetime", StringType(), False),
    StructField("arr_datetime", StringType(), False),
])
dt_udf = udf(to_utc, schema)

udf_inputs = [
    F.col(name)
    for name in ("FL_DATE", "CRS_DEP_TIME", "CRS_ARR_TIME", "origin_timezone", "dest_timezone", "CRS_ELAPSED_TIME")
]
out = df_interpolated.withColumn('processed', dt_udf(*udf_inputs)).cache()

# Flatten the struct: every other column first, then the two UTC timestamps
cols = [c for c in out.columns if c != "processed"] + ["processed.dep_datetime", "processed.arr_datetime"]
out = out.select(cols)

display(out)

Fill in missing CRS (i.e., scheduled) elapsed flight time:

In [0]:


# arr_datetime comes back from the UDF as a UTC string 'YYYY-MM-DDTHH:MM:SS'.
# NOTE(review): a string->timestamp cast is interpreted in the Spark session time zone;
# this assumes that session time zone is UTC — confirm.
out = out.withColumn("arr_datetime", F.col("arr_datetime").cast("timestamp"))

# CRS_ELAPSED_TIME is in minutes. Subtracting two timestamps yields an interval, and
# depending on the Spark version, casting that interval straight to int is either rejected
# or does not give minutes. Compute the difference from epoch seconds instead.
# Assumes sched_depart_utc is a timestamp column — TODO confirm.
scheduled_minutes = (
    (F.col("arr_datetime").cast("long") - F.col("sched_depart_utc").cast("long")) / 60
).cast("int")
out = out.withColumn(
    "CRS_ELAPSED_TIME",
    F.when(
        F.col("CRS_ELAPSED_TIME").isNull(),
        scheduled_minutes
    ).otherwise(F.col("CRS_ELAPSED_TIME"))
)
display(out)

In [0]:
# Sanity check: rows whose scheduled elapsed time is still missing after the fill above
out.where(F.col('CRS_ELAPSED_TIME').isNull()).count()

### Create features

In [0]:


# Module-level copy of the "prior flight is usable" test: the prior flight landed at this
# flight's origin and was scheduled within the 24 h before the cutoff. Note that
# add_lags_optimized builds its own local version rather than using this global.
_prior_landed_here = F.col("ORIGIN") == F.col("priorflight_dest")
_prior_within_day = F.col("priorflight_sched_deptime") >= F.col("twentysix_hours_prior_depart_UTC")
WhenConditions = _prior_landed_here & _prior_within_day

def add_lags_optimized(df):
    """Add prior-flight ("lag") features, using only information available at each flight's
    prediction cutoff `two_hours_prior_depart_UTC`.

    The prior flight is the previous row for the same TAIL_NUM, ordered by sched_depart_utc.
    It counts as a valid predecessor only when all of the following hold:
      - the tail number is known;
      - the prior flight's DEST equals this flight's ORIGIN;
      - it was scheduled to depart no earlier than 24 h before the cutoff;
      - it was not cancelled.
    Whenever there is no valid predecessor (including the first flight of each tail), the
    turnaround time and prior departure delay fall back to the last non-null value among the
    previous 100 flights on the same ORIGIN -> DEST route.

    Expects columns: TAIL_NUM, ORIGIN, DEST, CANCELLED, sched_depart_utc,
    two_hours_prior_depart_UTC, CRS_ELAPSED_TIME, DEP_DELAY, ARR_DELAY (the last three in
    minutes), arr_datetime, AIR_TIME, TAXI_IN, TAXI_OUT.

    Returns the input plus the engineered columns; the result is cached.
    """
    # Windows: previous flights of the same aircraft; previous 100 flights on the same route
    aircraft_window = Window.partitionBy("TAIL_NUM").orderBy("sched_depart_utc")
    route_window = Window.partitionBy("ORIGIN", "DEST").orderBy("sched_depart_utc").rowsBetween(-100, -1)
    one_minute = F.expr("INTERVAL 1 MINUTE")
    cutoff = F.col("two_hours_prior_depart_UTC")

    # All lagged columns are computed in a single pass over aircraft_window
    lagged_cols = [
        F.lag("ORIGIN").over(aircraft_window).alias("priorflight_origin"),
        F.lag("DEST").over(aircraft_window).alias("priorflight_dest"),
        F.lag("CANCELLED").over(aircraft_window).alias("priorflight_cancelled_true"),
        F.lag("sched_depart_utc").over(aircraft_window).alias("priorflight_sched_deptime"),
        F.lag("CRS_ELAPSED_TIME").over(aircraft_window).alias("priorflight_elapsed_time_calc_raw"),
        F.lag("DEP_DELAY").over(aircraft_window).alias("priorflight_depdelay_true_raw"),
        # Lag the prior flight's ACTUAL arrival (arr_time_true, built in base_df). Lagging
        # arr_datetime would give its scheduled arrival, so "arrived" would ignore arrival delays.
        F.lag("arr_time_true").over(aircraft_window).alias("priorflight_arr_time_true"),
    ]

    base_df = (df
        .withColumn("twentysix_hours_prior_depart_UTC",
                    (cutoff - F.expr("INTERVAL 24 HOURS")).cast("timestamp"))
        # Actual arrival of this flight = scheduled arrival + ARR_DELAY minutes. It must exist
        # before the lag so the next flight of the same tail can see it.
        .withColumn("arr_time_true",
                    F.col("arr_datetime").cast("timestamp") + (one_minute * F.col("ARR_DELAY")))
        .select("*", *lagged_cols)
    )

    prior_is_candidate = (
        (F.col("ORIGIN") == F.col("priorflight_dest")) &
        (F.col("priorflight_sched_deptime") >= F.col("twentysix_hours_prior_depart_UTC"))
    )
    # coalesce(..., False): with no predecessor (first flight of a tail, null CANCELLED, ...)
    # the raw condition is NULL, and ~NULL is also NULL, which would silently skip the
    # route-level fallbacks at the end. A null TAIL_NUM is never a valid chain.
    valid_prior = F.coalesce(
        prior_is_candidate
        & (F.col("priorflight_cancelled_true") != 1)
        & F.col("TAIL_NUM").isNotNull(),
        F.lit(False),
    )
    departed_by_cutoff = F.col("priorflight_deptime_true") <= cutoff

    result_df = (base_df
        .withColumn("priorflight_elapsed_time_calc",
            F.when(valid_prior, one_minute * F.col("priorflight_elapsed_time_calc_raw"))
        )
        .withColumn("priorflight_depdelay_true",
            F.when(valid_prior, F.col("priorflight_depdelay_true_raw"))
        )
        .withColumn("priorflight_deptime_true",
            F.when(valid_prior,
                F.col("priorflight_sched_deptime") + (one_minute * F.col("priorflight_depdelay_true"))
            )
        )
        .withColumn("priorflight_isdeparted",
            F.when(departed_by_cutoff & valid_prior, 1).otherwise(0)
        )
        # Prior departure delay as known at the cutoff:
        #   already departed              -> its true delay
        #   scheduled but not yet departed -> minutes elapsed since scheduled departure
        #   otherwise                     -> 0 (assume on time, in line with other logic)
        .withColumn("priorflight_depdelay_calc",
            F.when(departed_by_cutoff & valid_prior, F.col("priorflight_depdelay_true"))
             .when(
                (F.col("priorflight_sched_deptime") <= cutoff) &
                (F.col("priorflight_deptime_true") > cutoff) &
                valid_prior,
                (cutoff.cast('long') - F.col("priorflight_sched_deptime").cast('long')) / 60
             ).otherwise(F.lit(0.0))
        )
        .withColumn("priorflight_deptime_calc",
            F.col("priorflight_sched_deptime") + (one_minute * F.col("priorflight_depdelay_calc"))
        )
        .withColumn("priorflight_isdelayed_calc",
            F.when(
                (F.col("priorflight_depdelay_calc") >= 15) |
                (F.col('priorflight_cancelled_true') == 1), 1
            ).otherwise(0)
        )
        # NOTE(review): AIR_TIME / TAXI_IN / TAXI_OUT are the CURRENT flight's actuals, which are
        # not known at the cutoff — do not use elapsed_time_true as a model feature without care.
        .withColumn("elapsed_time_true",
            F.when(valid_prior,
                (F.col("AIR_TIME") + F.col("TAXI_IN") + F.col("TAXI_OUT")).cast("int")
            )
        )
        .withColumn("priorflight_isarrived_calc",
            F.when((F.col("priorflight_arr_time_true") <= cutoff) & valid_prior, 1).otherwise(0)
        )
        # Best estimate of the prior arrival: actual if it has arrived; otherwise the actual or
        # estimated departure plus the scheduled elapsed time
        .withColumn("priorflight_arr_time_calc",
            F.when(F.col("priorflight_isarrived_calc") == 1, F.col("priorflight_arr_time_true"))
             .when(
                (F.col("priorflight_isarrived_calc") == 0) & departed_by_cutoff,
                F.col("priorflight_deptime_true") + F.col("priorflight_elapsed_time_calc")
             ).otherwise(
                F.col("priorflight_deptime_calc") + F.col("priorflight_elapsed_time_calc")
             )
        )
        .withColumn("turnaround_time_calc",
            F.when(valid_prior,
                ((F.col("sched_depart_utc").cast("long") -
                  F.col("priorflight_arr_time_calc").cast("long")) / 60).cast("double")
            )
        )
        # Edge cases (no valid predecessor, which includes a cancelled one): latest value on the route
        .withColumn("turnaround_time_calc",
            F.when(~valid_prior,
                   F.last("turnaround_time_calc", ignorenulls=True).over(route_window))
             .otherwise(F.col("turnaround_time_calc"))
        )
        .withColumn("priorflight_depdelay_calc",
            F.when(~valid_prior,
                   F.last("priorflight_depdelay_calc", ignorenulls=True).over(route_window))
             .otherwise(F.col("priorflight_depdelay_calc"))
        )
    ).cache()

    return result_df

# Execute pipeline
# Execute the lag-feature pipeline on the UTC-enriched flights
result = add_lags_optimized(df=out)
display(result)


In [0]:
# Row count after lag engineering; as an action it also materializes the cached `result`
result.count() #sanity check

In [0]:
# checkpoint() returns a new DataFrame; reassign it so later cells use the truncated lineage
result = result.checkpoint()

In [0]:
# Null counts of the engineered prior-flight features, computed in one pass over `result`.
# "Delay estimate" is the continuous priorflight_depdelay_calc; the indicator
# priorflight_isdelayed_calc is already covered by the first line.
null_checks = [
    ("Delay indicator null count", "priorflight_isdelayed_calc"),
    ("Departed indicator null count", "priorflight_isdeparted"),
    ("Arrival indicator null count", "priorflight_isarrived_calc"),
    ("Est. Turnaround Time Null Count", "turnaround_time_calc"),
    ("Delay estimate null count", "priorflight_depdelay_calc"),
]
null_row = result.agg(
    *[F.count(F.when(F.col(column).isNull(), 1)).alias(column) for _, column in null_checks]
).first()

for label, column in null_checks:
    print(f"{label}: {null_row[column]}")


If turnaround time cannot be calculated from the record's own features, estimate it from prior records at the same origin airport.

In [0]:
# Fill remaining turnaround gaps with an EMA of the previous 4 rows at the same origin airport
result_imputed = smooth_column_optimized(
    result,
    "turnaround_time_calc",
    airport_col="ORIGIN",
    dt_col="two_hours_prior_depart_UTC",
    window_size=4,
)


In [0]:
# Turnaround nulls left after the EMA fill
remaining_turnaround_nulls = result_imputed.where(F.col('turnaround_time_calc').isNull()).count()
print(f"Est. Turnaround Time Null Count: {remaining_turnaround_nulls}")


In [0]:
# Replace the bogus 2237 wind-speed reading with 0 (see the EDA spot checks)
WIND_SPEED_OUTLIER = 2237
wind_speed = F.col("origin_HourlyWindSpeed")
df = result_imputed.withColumn(
    "origin_HourlyWindSpeed",
    F.when(wind_speed == WIND_SPEED_OUTLIER, 0).otherwise(wind_speed),
)

In [0]:


# Persist the cleaned + engineered dataset (overwrites any previous run)
output_path = "dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined_1y_cleaned_engineered.parquet"
writer = df.write.mode("overwrite")
writer.parquet(output_path)

In [0]:
# checkpoint() returns a new DataFrame; reassign so `df` actually uses the checkpointed data
df = df.checkpoint()

# EDA

### Setup

In [0]:
# EDA starts from the persisted engineered dataset so this section can run on its own
output_path = "dbfs:/student-groups/Group_4_1/interim/join_checkpoints/joined_1y_cleaned_engineered.parquet"
df = spark.read.parquet(output_path)

In [0]:
import pyspark.sql.functions as F
from pyspark.sql.functions import col,isnan, when, count, concat_ws, countDistinct, collect_set, rank, window, avg, hour, udf, isnan, pandas_udf
import matplotlib.pyplot as plt
import pandas as pd
import re
import numpy as np
from pyspark.sql import types
from pyspark.sql.types import *
from pyspark.sql import Window
import seaborn as sns

In [0]:
# Label: 1 if the flight departed 15+ minutes late or was cancelled, else 0
delayed_or_cancelled = (F.col('DEP_DELAY') >= 15) | (F.col('CANCELLED') == 1)
df = df.withColumn('outcome', F.when(delayed_or_cancelled, 1).otherwise(0))
              

## spot check - old

In [0]:
# Implausibly high wind speeds; these were also present in the original dataset
display(df.where(F.col('origin_HourlyWindSpeed') > 2000))

In [0]:
# Does not look accurate: many of these reports say 0KT, so the value must be incorrect
display(df.where("origin_HourlyWindSpeed > 2000"))

In [0]:
# Replace the bogus 2237 wind-speed reading with 0 (idempotent if already applied)
WIND_SPEED_OUTLIER = 2237
wind_speed = F.col("origin_HourlyWindSpeed")
df = df.withColumn(
    "origin_HourlyWindSpeed",
    F.when(wind_speed == WIND_SPEED_OUTLIER, 0).otherwise(wind_speed),
)

In [0]:
# Departure-delay-only label (15+ minutes late); unlike `outcome`, it ignores cancellations
df = df.withColumn('DEP_DEL15', F.when(F.col('DEP_DELAY') >= 15, 1).otherwise(0))

## Basic plots with Outcome

In [0]:


# Convert wind speed data to pandas
# Pull each wind feature (non-null rows only) with the outcome as a string category
frames = {
    feature: (
        df.where(F.col(feature).isNotNull())
          .withColumn("outcome", F.col("outcome").cast("string"))
          .select(feature, "outcome")
          .toPandas()
    )
    for feature in ("origin_HourlyWindSpeed", "origin_HourlyWindGustSpeed")
}
wind_df = frames["origin_HourlyWindSpeed"]
wind_gust_df = frames["origin_HourlyWindGustSpeed"]

# One boxplot per feature, split by outcome
fig, axes = plt.subplots(1, 2, figsize=(15, 6))
panels = [
    (wind_df, "origin_HourlyWindSpeed", "Hourly Wind Speed", "Hourly Wind Speed by Departure Delay"),
    (wind_gust_df, "origin_HourlyWindGustSpeed", "Hourly Wind Gust Speed", "Hourly Wind Gust Speed by Departure Delay"),
]
for ax, (frame, feature, ylabel, title) in zip(axes, panels):
    frame.boxplot(column=feature, by="outcome", ax=ax)
    ax.set_xlabel("Departure Delay (0 = No Delay, 1 = Delayed)")
    ax.set_ylabel(ylabel)
    ax.set_title(title)

# Clear the automatic grouped-boxplot figure title, then lay out
plt.suptitle("")
plt.tight_layout()
plt.show()


In [0]:
# Look at nonzero wind speeds: zero readings are excluded from BOTH panels, matching the
# "Nonzero" axis labels and titles (the wind speed frame previously kept zeros).

wind_df = df.filter(F.col('origin_HourlyWindSpeed').isNotNull()).filter(F.col('origin_HourlyWindSpeed') > 0) \
    .withColumn("outcome", F.col("outcome").cast("string")) \
    .select("origin_HourlyWindSpeed", "outcome") \
    .toPandas()

wind_gust_df = df.filter(F.col('origin_HourlyWindGustSpeed').isNotNull()).filter(F.col('origin_HourlyWindGustSpeed') > 0) \
    .withColumn("outcome", F.col("outcome").cast("string")) \
    .select("origin_HourlyWindGustSpeed", "outcome") \
    .toPandas()

# Create subplots
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# Boxplot for Hourly Wind Speed
wind_df.boxplot(column="origin_HourlyWindSpeed", by="outcome", ax=axes[0])
axes[0].set_xlabel("Departure Delay (0 = No Delay, 1 = Delayed)")
axes[0].set_ylabel("Nonzero Hourly Wind Speed")
axes[0].set_title("Nonzero Hourly Wind Speed by Departure Delay")

# Boxplot for Hourly Wind Gust Speed
wind_gust_df.boxplot(column="origin_HourlyWindGustSpeed", by="outcome", ax=axes[1])
axes[1].set_xlabel("Departure Delay (0 = No Delay, 1 = Delayed)")
axes[1].set_ylabel("Nonzero Hourly Wind Gust Speed")
axes[1].set_title("Nonzero Hourly Wind Gust Speed by Departure Delay")

# Clear the automatic grouped-boxplot figure title, then lay out
plt.suptitle("")
plt.tight_layout()
plt.show()


In [0]:




# Pull each temperature feature (non-null rows only) with the outcome as a string category
frames = {
    feature: (
        df.where(F.col(feature).isNotNull())
          .withColumn("outcome", F.col("outcome").cast("string"))
          .select(feature, "outcome")
          .toPandas()
    )
    for feature in ("origin_HourlyDryBulbTemperature", "origin_HourlyWetBulbTemperature")
}
drybulb_df = frames["origin_HourlyDryBulbTemperature"]
wetbulb_df = frames["origin_HourlyWetBulbTemperature"]

fig, axes = plt.subplots(1, 2, figsize=(15, 6))
panels = [
    (drybulb_df, "origin_HourlyDryBulbTemperature", "Hourly Dry Bulb Temp", "Hourly Dry Bulb by Departure Delay"),
    (wetbulb_df, "origin_HourlyWetBulbTemperature", "Hourly Wet Bulb", "Hourly Wet Bulb Temp by Departure Delay"),
]
for ax, (frame, feature, ylabel, title) in zip(axes, panels):
    frame.boxplot(column=feature, by="outcome", ax=ax)
    ax.set_xlabel("Departure Delay (0 = No Delay, 1 = Delayed)")
    ax.set_ylabel(ylabel)
    ax.set_title(title)

# Clear the automatic grouped-boxplot figure title, then lay out
plt.suptitle("")
plt.tight_layout()
plt.show()

# Wet bulb is reportedly more relevant for flights, but the distributions look similar for
# both outcome groups for both temperature types.


In [0]:




# Pull visibility and pressure change (non-null rows only) with the outcome as a string category
frames = {
    feature: (
        df.where(F.col(feature).isNotNull())
          .withColumn("outcome", F.col("outcome").cast("string"))
          .select(feature, "outcome")
          .toPandas()
    )
    for feature in ("origin_HourlyVisibility", "origin_HourlyPressureChange")
}
vis_df = frames["origin_HourlyVisibility"]
pressure_df = frames["origin_HourlyPressureChange"]

fig, axes = plt.subplots(1, 2, figsize=(15, 6))
panels = [
    (vis_df, "origin_HourlyVisibility", "Hourly Visibility", "Hourly Visibility by Departure Delay"),
    (pressure_df, "origin_HourlyPressureChange", "Hourly Pressure Tendency", "Hourly Pressure Tendency by Departure Delay"),
]
for ax, (frame, feature, ylabel, title) in zip(axes, panels):
    frame.boxplot(column=feature, by="outcome", ax=ax)
    ax.set_xlabel("Departure Delay (0 = No Delay, 1 = Delayed)")
    ax.set_ylabel(ylabel)
    ax.set_title(title)

# Clear the automatic grouped-boxplot figure title, then lay out
plt.suptitle("")
plt.tight_layout()
plt.show()



In [0]:
# Presentation version: Spearman correlations among weather features and the outcome

# Short display labels for each column
new_col_names = {
    "origin_HourlyDryBulbTemperature": "Dry Bulb Temp",
    "origin_HourlyWetBulbTemperature": "Wet Bulb Temp",
    "origin_HourlyWindSpeed": "Wind Speed",
    "origin_HourlyPrecipitation": "Precipitation",
    "origin_HourlyVisibility": "Visibility",
    "origin_HourlyWindGustSpeed": "Wind Gust",
    "origin_HourlyPressureChange": "Pressure Δ",
    "origin_HourlyRelativeHumidity": "Humidity",
    "outcome": "Departure Delay"
}

heatmap_df = df.select(list(new_col_names)).toPandas().rename(columns=new_col_names)
corr_matrix = heatmap_df.corr(method='spearman')
# Hide the upper triangle (incl. diagonal): the matrix is symmetric
mask = np.triu(np.ones_like(corr_matrix, dtype=bool))

heatmap_style = dict(
    mask=mask, annot=True, cmap="coolwarm", vmin=-1, vmax=1,
    fmt=".2f", linewidths=.5, cbar_kws={"shrink": 0.8},
)
plt.figure(figsize=(10, 8))
ax = sns.heatmap(corr_matrix, **heatmap_style)

# Rotate and align tick labels
for axis_name, rotation in (("x", 45), ("y", 0)):
    ax.tick_params(axis=axis_name, labelrotation=rotation, labelsize=10)
ax.set_xticklabels(ax.get_xticklabels(), ha='right', rotation=45)
ax.set_yticklabels(ax.get_yticklabels(), va='center')

plt.title("Weather Features Spearman Correlation Heatmap", pad=20, fontsize=14)
plt.tight_layout()
plt.show()


In [0]:
# List the column names of the engineered dataset
df.columns

In [0]:
# Presentation version: Spearman correlations among prior-flight features and the outcome

cols = ["priorflight_elapsed_time_calc","priorflight_depdelay_calc", "turnaround_time_calc",
        "priorflight_isarrived_calc", "priorflight_isdeparted", "priorflight_isdelayed_calc", "priorflight_cancelled_true"]

# Short display labels for each column (these keys, not `cols`, drive the selection)
new_col_names = {
    "priorflight_elapsed_time_calc": "Prior Flight Air Time",
    "priorflight_depdelay_calc": "Prior Flight Delay Time",
    "turnaround_time_calc": "Turnaround Time",
    "priorflight_isarrived_calc": "Prior Flight Arrived",
    "priorflight_isdeparted": "Prior Flight Departed",
    "priorflight_isdelayed_calc": "Prior Flight Delayed",
    "priorflight_cancelled_true": "Prior Flight Cancelled",
    "outcome": "Current Flight Delayed"
}

heatmap_df = df.select(list(new_col_names)).toPandas().rename(columns=new_col_names)
corr_matrix = heatmap_df.corr(method='spearman')
# Hide the upper triangle (incl. diagonal): the matrix is symmetric
mask = np.triu(np.ones_like(corr_matrix, dtype=bool))

heatmap_style = dict(
    mask=mask, annot=True, cmap="coolwarm", vmin=-1, vmax=1,
    fmt=".2f", linewidths=.5, cbar_kws={"shrink": 0.8},
)
plt.figure(figsize=(10, 8))
ax = sns.heatmap(corr_matrix, **heatmap_style)

# Rotate and align tick labels
for axis_name, rotation in (("x", 45), ("y", 0)):
    ax.tick_params(axis=axis_name, labelrotation=rotation, labelsize=10)
ax.set_xticklabels(ax.get_xticklabels(), ha='right', rotation=45)
ax.set_yticklabels(ax.get_yticklabels(), va='center')

plt.title("Prior Flight Features: Spearman Correlation Heatmap", pad=20, fontsize=14)
plt.tight_layout()
plt.show()


In [0]:




# Pull prior-flight delay and turnaround (non-null rows only) with the outcome as a string category
frames = {
    feature: (
        df.where(F.col(feature).isNotNull())
          .withColumn("outcome", F.col("outcome").cast("string"))
          .select(feature, "outcome")
          .toPandas()
    )
    for feature in ("priorflight_depdelay_calc", "turnaround_time_calc")
}
depdel_calc_df = frames["priorflight_depdelay_calc"]
turnaround_calc_df = frames["turnaround_time_calc"]

fig, axes = plt.subplots(1, 2, figsize=(15, 6))
panels = [
    (depdel_calc_df, "priorflight_depdelay_calc", "Estimated Prior Flight Dep. Delay", "Current Flight Departure Delay by Prior Departure Delay"),
    (turnaround_calc_df, "turnaround_time_calc", "Estimated Turnaround Time", "Current Flight Departure Delay by Turnaround Time"),
]
for ax, (frame, feature, ylabel, title) in zip(axes, panels):
    frame.boxplot(column=feature, by="outcome", ax=ax)
    ax.set_xlabel("Departure Delay (0 = No Delay, 1 = Delayed)")
    ax.set_ylabel(ylabel)
    ax.set_title(title)

# Clear the automatic grouped-boxplot figure title, then lay out
plt.suptitle("")
plt.tight_layout()
plt.show()



## Geohash exploration for Mapping

In [0]:
# Mean departure delay per geohash cell, highest first
delay_by_cell = df.groupBy('geohash').agg(F.avg("DEP_DELAY").alias("AVG_DEP_DELAY"))
display(delay_by_cell.orderBy(F.desc("AVG_DEP_DELAY")))

In [0]:
# Airports in cell '9p' (longest average delay times)
df.where(F.col('geohash') == '9p').select('origin_airport_name').distinct().show()

In [0]:
# List the distinct origin airports that fall inside geohash cell 'bs'.
df.filter(F.col('geohash')=='bs').select('origin_airport_name').distinct().show()
# Author's note: 'bs' ranked last (shortest average delay) in the geohash table above.

In [0]:
# Average departure delay per "ORIGIN -> DEST" route, highest first.
paths_df = (
    df.withColumn("FLIGHT_PATH", F.concat_ws(" -> ", "origin_airport_name", "dest_airport_name"))
      .groupBy("FLIGHT_PATH")
      .agg(F.avg("DEP_DELAY").alias("AVG_DEP_DELAY"))
      .orderBy(F.desc("AVG_DEP_DELAY"))
)

In [0]:
# Routes ranked by average departure delay (highest first).
display(paths_df)

In [0]:
# Inspect extreme departure delays: more than 720, i.e. 12 hours if DEP_DELAY is in
# minutes (the maps below label it that way).
display(df.filter(F.col('DEP_DELAY')>720))

In [0]:
# Raw joined rows (`df_`, loaded at the top of the notebook) for the
# EGLIN AFB AIRPORT -> SARASOTA/BRADENTON INTL AP route.
display(df_.filter(F.col('origin_airport_name')=='EGLIN AFB AIRPORT').filter(F.col('dest_airport_name')=='SARASOTA/BRADENTON INTL AP'))

In [0]:
# Average departure delay per origin airport, highest first.
display(
    df.groupBy('origin_airport_name')
      .agg(F.avg("DEP_DELAY").alias("AVG_DEP_DELAY"))
      .orderBy(F.desc("AVG_DEP_DELAY"))
)

In [0]:
# Average departure delay per destination airport, highest first.
display(
    df.groupBy('dest_airport_name')
      .agg(F.avg("DEP_DELAY").alias("AVG_DEP_DELAY"))
      .orderBy(F.desc("AVG_DEP_DELAY"))
)

In [0]:
# Number of rows with no tail number. Such rows cannot be chained to the same
# aircraft's previous leg by the TAIL_NUM-partitioned window used below.
df.filter(F.col('TAIL_NUM').isNull()).count()

In [0]:
# Raw joined rows (`df_`) with no origin station latitude
# (presumably flights with no matched weather station -- TODO confirm).
display(df_.filter(F.col('origin_station_lat').isNull()))

In [0]:
# Average departure delay for flights departing EGLIN AFB AIRPORT.
display(df.filter(F.col('origin_airport_name')=='EGLIN AFB AIRPORT').agg(F.avg("DEP_DELAY")))

In [0]:
# Bare expression: show the repr of `df` (its column names and types).
df

In [0]:
# All raw joined rows departing EGLIN AFB AIRPORT, largest departure delay first.
display(df_.filter(F.col('origin_airport_name')=='EGLIN AFB AIRPORT').orderBy(F.col('DEP_DELAY').desc()))

In [0]:
# Average departure delay per origin airport type, highest first.
display(
    df.groupBy('origin_type')
      .agg(F.avg("DEP_DELAY").alias("AVG_DEP_DELAY"))
      .orderBy(F.desc("AVG_DEP_DELAY"))
)

In [0]:
# Add weather observed at the aircraft's previous departure airport: for each flight,
# take origin visibility / precipitation from the same TAIL_NUM's preceding
# scheduled departure (ordered by sched_depart_utc).
# Spark puts every null TAIL_NUM into ONE window partition, so a plain lag would
# borrow weather from an unrelated aircraft. Those rows are left null instead.
prior_leg_window = Window.partitionBy("TAIL_NUM").orderBy("sched_depart_utc")
has_tail_num = F.col("TAIL_NUM").isNotNull()

df = df.withColumns({
    "prior_origin_visibility":
        F.when(has_tail_num, F.lag("origin_HourlyVisibility").over(prior_leg_window)),
    "prior_origin_precipitation":
        F.when(has_tail_num, F.lag("origin_HourlyPrecipitation").over(prior_leg_window)),
})

In [0]:
# Inspect `df` after adding the prior_origin_* columns.
display(df) #sanity check

In [0]:


# Spearman correlations among current and prior-leg origin weather features and the outcome label.
cols = [
    "origin_HourlyDryBulbTemperature",
    "origin_HourlyWetBulbTemperature",
    "origin_HourlyWindSpeed",
    "origin_HourlyPrecipitation",
    "origin_HourlyVisibility",
    "origin_HourlyWindGustSpeed",
    "prior_origin_visibility",
    "prior_origin_precipitation",
]
heatmap_df = df.select(*cols, "outcome").toPandas()

corr_matrix = heatmap_df.corr('spearman')
# True cells are hidden (diagonal + upper triangle), so each pair appears once below the diagonal.
mask = np.triu(np.ones_like(corr_matrix, dtype=bool))

plt.figure(figsize=(8, 6))
sns.heatmap(
    corr_matrix,
    mask=mask,
    annot=True,
    cmap="coolwarm",
    vmin=-1,
    vmax=1,
)
plt.title("Correlation Heatmap")
plt.show()

In [0]:
# List the available columns (used to choose the prior-flight features plotted below).
df.columns

In [0]:
# Spearman correlations between the prior-flight features and the outcome label.
# Display names are keyed by column name, and the tick labels are derived from
# corr_matrix itself. They therefore cannot drift out of order or length if a
# column is added, removed, or dropped by .corr().
prior_flight_labels = {
    "priorflight_depdelay_final": "Prior Flight Dep Delay",
    "priorflight_crs_elapsed_time": "Prior Flight CRS Elapsed Time",
    "priorflight_distance": "Prior Flight Distance",
    "priorflight_isdelayed": "Prior Flight Is Delayed",
    "priorflight_arrived": "Prior Flight Arrived",
    "outcome": "Dep Delay",
}
cols = ["priorflight_depdelay_final", "priorflight_crs_elapsed_time", "priorflight_distance", "priorflight_isdelayed", "priorflight_arrived"]
flights_heatmap = df.select(cols + ["outcome"]).toPandas()

corr_matrix = flights_heatmap.corr('spearman')
# Hide the diagonal and upper triangle; each pair is shown once in the lower triangle.
mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
tick_labels = [prior_flight_labels.get(c, c) for c in corr_matrix.columns]

plt.figure(figsize=(8, 6))
sns.heatmap(corr_matrix, mask=mask, annot=True, cmap="coolwarm", vmin=-1, vmax=1,
            xticklabels=tick_labels, yticklabels=tick_labels)
plt.title("Correlation Heatmap")
plt.xticks(rotation=75)
plt.show()

## Maps

In [0]:
# Install the mapping dependencies. python-geohash is also installed in the first
# cell, so repeating it here is redundant but harmless.
# NOTE(review): on Databricks, `!pip` installs into the driver's Python only, and
# `%pip` would reset the Python state (losing `df`) if run this late. Prefer moving
# pinned installs to the top of the notebook.
!pip install python-geohash
!pip install geopandas

In [0]:
import plotly.express as px
import plotly.graph_objects as go
import geohash
import geopandas as gpd
from shapely.geometry import Polygon
import json

### by geohash

In [0]:
# Average departure delay per geohash, collected to pandas for mapping.
# The null-geohash group is dropped because it has no cell to draw.
ghdf = (
    df.groupBy('geohash')
      .agg(F.avg('DEP_DELAY').alias('AVG_DELAY'))
      .toPandas()
      .dropna(subset=['geohash'])
)

In [0]:
def geohash_to_polygon(gh):
    """Return the rectangular cell of geohash ``gh`` as a closed shapely Polygon.

    Vertices are (lon, lat) pairs traced SW -> SE -> NE -> NW and back to SW,
    using the cell bounds from ``geohash.bbox`` (keys 's', 'w', 'n', 'e').
    """
    cell = geohash.bbox(gh)
    west, south, east, north = cell['w'], cell['s'], cell['e'], cell['n']
    ring = [(west, south), (east, south), (east, north), (west, north)]
    ring.append(ring[0])  # explicitly close the ring
    return Polygon(ring)

# Attach each geohash's cell polygon, then wrap the frame as a GeoDataFrame whose
# (lon, lat) coordinates are declared as EPSG:4326.
ghdf["geometry"] = ghdf["geohash"].map(geohash_to_polygon)
gdf = gpd.GeoDataFrame(ghdf, geometry="geometry", crs="EPSG:4326")

In [0]:
# GeoJSON for the choropleth. Indexing by geohash first makes each feature's id the
# geohash string, which `locations='geohash'` matches in the next cell.
geojson = json.loads(gdf.set_index('geohash').to_json())  # Set geohash as index first

In [0]:
# Choropleth of average departure delay per geohash cell over an OpenStreetMap basemap.
fig = px.choropleth_mapbox(
    ghdf,
    geojson=geojson,
    locations='geohash',                      # matched against the GeoJSON feature ids
    color='AVG_DELAY',
    color_continuous_scale="YlOrRd",
    range_color=(gdf['AVG_DELAY'].min(), gdf['AVG_DELAY'].max()),
    mapbox_style="open-street-map",
    zoom=3,
    center=dict(lat=37.6, lon=-95.6),         # centred on the US
    opacity=0.3,
    labels=dict(AVG_DELAY='Avg Delay (minutes)'),
)

fig.update_traces(marker_line_width=0)        # no cell outlines

fig.update_layout(
    margin=dict(r=20, t=40, l=20, b=20),
    coloraxis_colorbar=dict(title='Delay', thickness=20, len=0.5),
)

fig.show()

### by origin airport

In [0]:
# Average departure delay per origin airport (non-negative averages only), joined to
# the airport's coordinates and collected to pandas.
# The coordinate lookup is deduplicated BEFORE the join. Joining against the raw
# per-flight `df` would produce one row per flight and pull all of them to the driver.
adf = (
    df.groupBy('origin_airport_name')
      .agg(F.avg('DEP_DELAY').alias('AVG_DELAY'))
      .filter(F.col('AVG_DELAY') >= 0)
      .join(
          df.select('origin_airport_name', 'origin_airport_lat', 'origin_airport_lon')
            .dropDuplicates(['origin_airport_name']),
          on='origin_airport_name',
          how='left_outer',
      )
      .toPandas()
)

In [0]:
# Keep one non-negative-average row per airport, then attach each airport's flight
# count (used for bubble sizes). groupBy().count() already yields unique keys.
adf = adf.loc[adf['AVG_DELAY'] >= 0].drop_duplicates(subset=['origin_airport_name'])
adf_ = df.groupBy('origin_airport_name').count().toPandas()
adf = adf.merge(adf_, on='origin_airport_name', how='left')

In [0]:
# Bubble map of origin airports: bubble size tracks each airport's share of all
# flights, and colour tracks its average departure delay (adf built above).
fig = go.Figure()

fig.add_trace(go.Scattergeo(
    locationmode='USA-states',  # only used with `locations`; points are placed by lon/lat
    lon=adf['origin_airport_lon'],
    lat=adf['origin_airport_lat'],
    text=adf['origin_airport_name'] + '<br>Avg Delay: ' + adf['AVG_DELAY'].round(1).astype(str) + ' mins',
    marker=dict(
        # Vectorized share of flights: the total is summed once, not once per airport.
        size=adf['count'] / adf['count'].sum() * 500,
        sizemin=2,  # otherwise we can't see the alaska ones
        color=adf['AVG_DELAY'],  # marker colour = average delay
        colorscale='YlOrRd',
        cmin=adf['AVG_DELAY'].min(),
        cmax=adf['AVG_DELAY'].max(),
        colorbar=dict(
            title="Avg Delay (mins)"
        ),
        opacity=0.8,
        line=dict(width=.3)  # thin border to help visibility
    )
))

fig.update_layout(
    title='Average Flight Delays by Airport',
    geo=dict(
        showland=True,
        landcolor="teal",
        subunitwidth=1,
        countrywidth=1,
        countrycolor="white",
        showlakes=False,
        bgcolor="lightblue",
        subunitcolor="white"
    ),
    margin={"r":0,"t":50,"l":0,"b":0}
)

fig.show()

### by previous airport

In [0]:
# Average departure delay grouped by the airport the aircraft's prior leg departed
# from (non-negative averages only), collected to pandas.
pdf = (
    df.groupBy('priorflight_origin')
      .agg(F.avg('DEP_DELAY').alias('AVG_DEP_DELAY'))
      .where(F.col('AVG_DEP_DELAY') >= 0)
      .toPandas()
      .drop_duplicates(subset=['priorflight_origin'])
)


In [0]:
# Coordinates for every prior-flight origin. Each prior origin code is matched to the
# lat/lon recorded for that code on the current-flight side of `df`.
# The coordinate lookup is deduplicated first, so the join yields about one row per
# airport instead of one row per flight departing that airport.
priors = (
    df.select('priorflight_origin').distinct()
      .withColumnRenamed('priorflight_origin', 'origin')
      .join(
          df.select('origin', 'origin_airport_lat', 'origin_airport_lon')
            .dropDuplicates(['origin']),
          on='origin',
          how='left',
      )
)


In [0]:
# Keep one coordinate row per airport code and collect to pandas.
priors=priors.drop_duplicates(subset=['origin']).toPandas()

In [0]:
# Attach coordinates to the per-prior-origin delay averages. This is a left join, so
# airports without coordinates stay in the frame with NaN lat/lon.
pdf=pd.merge(pdf, priors, left_on='priorflight_origin', right_on='origin', how='left')

In [0]:
# Flight count per prior origin (for bubble sizes); groupBy().count() already has unique keys.
pdf_ = df.groupBy('priorflight_origin').count().toPandas()
pdf = pdf.merge(pdf_, on='priorflight_origin', how='left')

In [0]:
# Bubble map keyed by the PRIOR-flight origin: each point is an airport the aircraft
# came from, coloured by the average departure delay of the following flight and
# sized by its share of flights (pdf built above).
fig = go.Figure()

fig.add_trace(go.Scattergeo(
    locationmode='USA-states',  # only used with `locations`; points are placed by lon/lat
    lon=pdf['origin_airport_lon'],
    lat=pdf['origin_airport_lat'],
    text=pdf['priorflight_origin'] + '<br>Avg Delay: ' + pdf['AVG_DEP_DELAY'].round(1).astype(str) + ' mins',
    marker=dict(
        # Vectorized share of flights: the total is summed once, not once per airport.
        size=pdf['count'] / pdf['count'].sum() * 500,
        sizemin=2,  # otherwise we can't see the alaska ones
        color=pdf['AVG_DEP_DELAY'],
        colorscale='YlOrRd',
        cmin=pdf['AVG_DEP_DELAY'].min(),
        cmax=pdf['AVG_DEP_DELAY'].max(),
        colorbar=dict(
            title="Avg Delay (mins)"
        ),
        opacity=0.8,
        line=dict(width=.3)  # thin border to help visibility
    )
))

fig.update_layout(
    # Distinct title: this map groups by prior-flight origin, not by current origin.
    title='Average Departure Delay by Prior-Flight Origin Airport',
    geo=dict(
        showland=True,
        landcolor="teal",
        subunitwidth=1,
        countrywidth=1,
        countrycolor="white",
        showlakes=False,
        bgcolor="lightblue",
        subunitcolor="white"
    ),
    margin={"r":0,"t":50,"l":0,"b":0}
)

fig.show()

In [0]:
import plotly.graph_objects as go

# Overlay of two airport bubble maps, each with its own colorbar:
#   * "Arriving From": average departure delay of flights grouped by the airport the
#     aircraft's prior leg departed from (pdf).
#   * "Departing Delayed From": average departure delay grouped by current origin (adf).
# Bubble size is each airport's share of flights within its own frame.
fig = go.Figure()

# Trace 1: Arriving From
fig.add_trace(go.Scattergeo(
    locationmode='USA-states',  # only used with `locations`; points are placed by lon/lat
    lon=pdf['origin_airport_lon'],
    lat=pdf['origin_airport_lat'],
    text=pdf['priorflight_origin'] + '<br>Avg Delay: ' + pdf['AVG_DEP_DELAY'].round(1).astype(str) + ' mins',
    marker=dict(
        size=pdf['count'] / pdf['count'].sum() * 500,  # vectorized share of flights
        sizemin=2,
        color=pdf['AVG_DEP_DELAY'],
        colorscale='YlOrRd',
        cmin=pdf['AVG_DEP_DELAY'].min(),
        cmax=pdf['AVG_DEP_DELAY'].max(),
        colorbar=dict(
            # The plotted value is DEP_DELAY of the next leg, not an arrival delay.
            title="Dep Delay by Prior Origin (mins)",
            x=0.85,  # Positioned to the left of Departure colorbar
            y=0.5,
            len=0.4
        ),
        opacity=0.3,
        line=dict(width=0.3)
    ),
    name="Arriving From"
))

# Trace 2: Departing Delayed From
fig.add_trace(go.Scattergeo(
    locationmode='USA-states',
    lon=adf['origin_airport_lon'],
    lat=adf['origin_airport_lat'],
    text=adf['origin_airport_name'] + '<br>Avg Delay: ' + adf['AVG_DELAY'].round(1).astype(str) + ' mins',
    marker=dict(
        size=adf['count'] / adf['count'].sum() * 500,  # vectorized share of flights
        sizemin=2,
        color=adf['AVG_DELAY'],
        colorscale='Viridis',
        cmin=adf['AVG_DELAY'].min(),
        cmax=adf['AVG_DELAY'].max(),
        colorbar=dict(
            title="Departure Delay (mins)",
            x=1.0,  # Positioned to the right of the prior-origin colorbar
            y=0.5,
            len=0.4
        ),
        opacity=0.3,
        line=dict(width=0.3)
    ),
    name="Departing Delayed From"
))

fig.update_layout(
    title='Average Flight Delays by Airport (Arriving vs Departing)',
    geo=dict(
        showland=True,
        landcolor="teal",
        subunitcolor="white",
        bgcolor="lightblue"
    ),
    margin={"r":150,"t":50,"l":0,"b":0}  # extra right margin for the two colorbars
)

fig.show()


# Duplicates

In [0]:
# Reload `df` from the saved checkpoint for the duplicate checks below. This replaces
# the `df` used above, including any columns that were added in earlier cells.
# The path is built from team_BASE_DIR, like the other reads and writes in this notebook.
df = spark.read.parquet(f"{team_BASE_DIR}/interim/join_checkpoints/joined_1y_weather_cleaned_combo_pfd.parquet")

In [0]:
# (sched_depart_utc, TAIL_NUM) keys that occur more than once -- candidate duplicate flights.
dups=df.groupBy('sched_depart_utc','TAIL_NUM').count().filter(F.col('count')>1)

In [0]:
# Duplicate keys that have a real tail number (groupBy lumps all null TAIL_NUMs together).
display(dups.filter(F.col('TAIL_NUM').isNotNull()))

In [0]:
# Bring back the full rows for every duplicated (sched_depart_utc, TAIL_NUM) key.
dups_info=dups.join(df,on=['sched_depart_utc','TAIL_NUM'], how='left')

In [0]:
# DataFrame.checkpoint() returns a NEW checkpointed DataFrame rather than changing
# dups_info in place. Keep the result so the cells below reuse the materialized data
# instead of recomputing the join.
dups_info = dups_info.checkpoint()

In [0]:
# Duplicate rows that have a tail number, grouped by aircraft and then by time.
display(dups_info.filter(F.col('TAIL_NUM').isNotNull()).orderBy('TAIL_NUM','sched_depart_utc'))

In [0]:
# All duplicate rows (including null TAIL_NUM), ordered by time and then by aircraft.
display(dups_info.orderBy('sched_depart_utc','TAIL_NUM'))

In [0]:
# Number of rows involved in duplicated keys.
dups_info.count()

In [0]:
# Duplicate rows that were NOT cancelled yet have no DEP_DELAY. The author's
# trailing note records the observed conclusion.
dups_info.filter(F.col('CANCELLED')==0).filter(F.col('DEP_DELAY').isNull()).count() #they're all either delayed or cancelled flights

# Sandbox

## Optimization: exponential-smoothing UDF

In [0]:
# Cast the origin weather columns to float so they can be smoothed numerically.
# NOTE(review): df_interpolate is not created in this part of the notebook; it must
# come from an earlier cell.
df_interpolate = df_interpolate.withColumns({
    name: df_interpolate[name].cast("float")
    for name in [
        "origin_HourlyPrecipitation",
        "origin_HourlyWindGustSpeed",
        "origin_HourlyWindSpeed",
        "origin_HourlyDewPointTemperature",
        "origin_HourlyDryBulbTemperature",
        "origin_HourlyPressureChange",
        "origin_HourlyRelativeHumidity",
        "origin_HourlyWetBulbTemperature",
        "origin_HourlyVisibility",
    ]
})

In [0]:


def smooth_column_optimized(
    df, 
    col_name, 
    station_col="origin_STATION", 
    dt_col="sched_depart_date_time", 
    alpha=0.5, 
    window_size=6
):
    """Fill nulls in ``col_name`` with an exponentially weighted mean of recent readings.

    For every row, the non-null values of ``col_name`` are collected into an array.
    They come from the same ``station_col``, covering the current row plus the
    preceding ``window_size`` rows ordered by ``dt_col``. A pandas UDF reduces each
    array to the last value of
    ``pd.Series(arr).ewm(alpha=alpha, ignore_na=True).mean()``. That value replaces
    ``col_name`` only where it is null; observed values are kept. Rows with no
    usable history fall back to 0.0.

    Args:
        df: Spark DataFrame to smooth.
        col_name: numeric column to fill (cast to double first).
        station_col: column identifying the station (window partition).
        dt_col: column giving the observation order within a station.
        alpha: smoothing factor passed to ``pandas.Series.ewm``.
        window_size: number of preceding rows included in each window.

    Returns:
        A new DataFrame in which ``col_name`` is a null-filled double column.
        The temporary helper columns are dropped.
    """
    history_col = f"non_null_{col_name}"
    smoothed_col = f"smoothed_{col_name}"

    # Cast up front so collect_list builds a numeric array. The final coalesce with
    # the double UDF output would produce a double column anyway.
    df = df.withColumn(col_name, F.col(col_name).cast(DoubleType()))

    # 1. Per-station window: the current row plus the previous `window_size` rows.
    # NOTE(review): ordering by the raw column (instead of .cast("long")) also works
    # when dt_col is a timestamp string. Casting such a string to long yields null and
    # would leave the window order undefined.
    window_spec = (
        Window.partitionBy(station_col)
              .orderBy(F.col(dt_col))
              .rowsBetween(-window_size, 0)
    )

    # 2. collect_list skips nulls, so each array holds only observed values.
    df = df.withColumn(history_col, F.collect_list(F.col(col_name)).over(window_spec))

    # 3. Vectorized UDF. Each element of `histories` is ONE row's array, so the EWM must
    #    run per element. A Series of arrays has object dtype, so a dtype check or a
    #    Series-level .ewm cannot be used here.
    @pandas_udf(DoubleType())
    def exponential_smoothing_pandas(histories: pd.Series) -> pd.Series:
        def last_ewm(values):
            if values is None or len(values) == 0:
                # NaN is converted to null by Spark's Arrow transfer, so the
                # coalesce below falls through to 0.0.
                return float("nan")
            return pd.Series(values, dtype="float64").ewm(alpha=alpha, ignore_na=True).mean().iloc[-1]

        return histories.apply(last_ewm).astype("float64")

    # 4. Keep observed values, fill gaps with the smoothed value (or 0.0), drop helpers.
    return (
        df.withColumn(smoothed_col, exponential_smoothing_pandas(F.col(history_col)))
          .withColumn(col_name, F.coalesce(F.col(col_name), F.col(smoothed_col), F.lit(0.0)))
          .drop(history_col, smoothed_col)
    )


In [0]:


@pandas_udf(DoubleType())
def exponential_smoothing_pandas(values: pd.Series) -> pd.Series:
    """Vectorized UDF: reduce each row's array of recent readings to one smoothed value.

    ``values`` is the pandas form of an array column: one array per row, e.g. the
    output of ``collect_list`` over a window. For each array, the result is the last
    value of ``pd.Series(arr).ewm(alpha=0.5, ignore_na=True).mean()``. Missing or
    empty arrays give NaN, which Spark's Arrow transfer turns into null.
    """
    # A Series of arrays has object dtype, so neither a dtype check nor a
    # Series-level .ewm (which would smooth ACROSS rows) can be used here.
    # The EWM has to run inside each row's array.
    def _last_ewm(window_values):
        if window_values is None or len(window_values) == 0:
            return float("nan")
        return pd.Series(window_values, dtype="float64").ewm(alpha=0.5, ignore_na=True).mean().iloc[-1]

    return values.apply(_last_ewm).astype("float64")

def smooth_column_optimized(
    df, 
    col_name, 
    station_col="origin_STATION", 
    dt_col="sched_depart_date_time", 
    window_size=6
):
    """Fill nulls in ``col_name`` from a smoothed window of recent readings.

    Casts ``col_name`` to double. For each row, it collects the non-null values from
    the same ``station_col`` in the current row plus the preceding ``window_size``
    rows (ordered by ``dt_col``). It then passes those arrays to
    ``exponential_smoothing_pandas``, a pandas UDF defined alongside this function
    that is expected to map each array to one double. Observed values are kept;
    nulls are replaced by the smoothed value, or 0.0 when that is also null.

    Returns:
        A new DataFrame. The temporary helper columns are dropped.
    """
    # Column-specific temp names so existing columns (e.g. one already called
    # "smoothed") are never overwritten and then dropped.
    history_col = f"non_null_{col_name}"
    smoothed_col = f"smoothed_{col_name}"

    # 1. Cast to double so the collected arrays are numeric. Nulls are kept: they are
    #    exactly what this function fills.
    df = df.withColumn(col_name, F.col(col_name).cast(DoubleType()))

    # 2. Per-station window: the current row plus the previous `window_size` rows.
    # NOTE(review): ordering by the raw column (instead of .cast("long")) also works
    # when dt_col is a timestamp string. Casting such a string to long yields null.
    window_spec = (
        Window.partitionBy(station_col)
              .orderBy(F.col(dt_col))
              .rowsBetween(-window_size, 0)
    )

    # 3. collect_list skips nulls, so each array holds only observed values.
    df = df.withColumn(history_col, F.collect_list(col_name).over(window_spec))

    # 4. Smooth each array, then fill only the gaps.
    return (
        df.withColumn(smoothed_col, exponential_smoothing_pandas(history_col))
          .withColumn(
              col_name,
              F.coalesce(F.col(col_name), F.col(smoothed_col), F.lit(0.0))
          )
          .drop(history_col, smoothed_col)
    )


In [0]:
# Smooth wind speed first; the remaining weather columns are smoothed in a loop below.
df_interpolated=smooth_column_optimized(df_interpolate, 'origin_HourlyWindSpeed')

In [0]:
# Row count of `df`, for comparison with the smoothed frame.
df.count()

In [0]:
# Remaining nulls in wind speed after smoothing. Expected 0, because gaps are
# coalesced with the smoothed value or 0.0.
df_interpolated.filter(F.col('origin_HourlyWindSpeed').isNull()).count()

In [0]:
# Smoothing adds/removes no rows; this should match the input row count.
df_interpolated.count()

In [0]:
# Smooth the remaining origin weather columns one at a time.
# The loop variable is deliberately NOT named `col`: that would shadow
# pyspark.sql.functions.col (imported at the top of the notebook), and any later
# cell calling col(...) would then fail with "'str' object is not callable".
columns_to_smooth = ['origin_HourlyDewPointTemperature','origin_HourlyDryBulbTemperature','origin_HourlyPressureChange','origin_HourlyRelativeHumidity','origin_HourlyWetBulbTemperature','origin_HourlyPrecipitation','origin_HourlyWindGustSpeed']

for col_name in columns_to_smooth:
    df_interpolated = smooth_column_optimized(df_interpolated, col_name)



In [0]:
# Row count after smoothing all columns; should still match the input row count.
df_interpolated.count()

In [0]:

# Null/NaN count per smoothed column. These should all be 0 if smoothing filled every gap.
# Functions are accessed through the `F` alias so this cell does not depend on the
# bare global `col`, which a loop variable can shadow (the previous cell's loop
# originally did exactly that).
null_counts = df_interpolated.select(
    [F.count(F.when(F.col(c).isNull() | F.isnan(c), c)).alias(c) for c in columns_to_smooth]
)

display(null_counts)

In [0]:
# Final row count before writing; should match the counts above.
df_interpolated.count()

In [0]:


# Persist the smoothed frame (overwriting any previous run). The path is built from
# team_BASE_DIR, the same way the read-back cell below builds it, so the two
# cannot diverge.
output_path = f"{team_BASE_DIR}/interim/join_checkpoints/joined_1y_weather_cleaned_test.parquet"
(
    df_interpolated.write
    .mode("overwrite")
    .parquet(output_path)
)




In [0]:
# Read the smoothed output back to verify the write.
test = spark.read.parquet(f"{team_BASE_DIR}/interim/join_checkpoints/joined_1y_weather_cleaned_test.parquet")

In [0]:
# Row count of the re-read file; should match df_interpolated.count() above.
test.count()