# Setup

In [0]:
!pip install timezonefinder

In [0]:
import pyspark.sql.functions as F
from pyspark.sql.functions import col,isnan, when, count, concat_ws, countDistinct, collect_set, rank, window, avg, hour, udf
import matplotlib.pyplot as plt
import pandas as pd
import re
import numpy as np
from pyspark.sql import types
from pyspark.sql.types import *
from pyspark.sql import Window
from itertools import combinations
from timezonefinder import TimezoneFinder
import pytz
from datetime import datetime


# Root of the shared course mount. Listing it shows which dataset samples are
# available (there are other, smaller samples such as a 1-day extract).
data_BASE_DIR = "dbfs:/mnt/mids-w261/"
mount_listing = dbutils.fs.ls(data_BASE_DIR)
display(mount_listing)

In [0]:

# One-year weather sample, read from parquet.
ydf = spark.read.parquet(f"dbfs:/mnt/mids-w261/datasets_final_project_2022/parquet_weather_data_1y/")

# Raw MSHR text file loaded as a single-column frame; each cell is parsed later by
# parse_fixed_width.
# NOTE(review): read_csv treats the first line as a header by default, so if
# mshr_standard.txt has no header row its first record becomes the column name and is
# never parsed -- confirm, and pass header=None if that is the case.
mshr = pd.read_csv('mshr_standard.txt',sep='\t') #via https://www.ncei.noaa.gov/access/homr/reports

In [0]:
# Station metadata: every parquet file under stations_data/.
stations = spark.read.parquet(
    "dbfs:/mnt/mids-w261/datasets_final_project_2022/stations_data/*"
)


In [0]:
# Three-month weather sample, plus the raw MSHR text file
# (via https://www.ncei.noaa.gov/access/homr/reports).
qdf = spark.read.parquet(
    "dbfs:/mnt/mids-w261/datasets_final_project_2022/parquet_weather_data_3m/"
)
mshr = pd.read_csv('mshr_standard.txt', sep='\t')


## Clean MSHR

In [0]:
def parse_fixed_width(row):
    """Split one fixed-width MSHR line into a dict of whitespace-stripped strings.

    Args:
        row: A single text line; each field is taken from a fixed character range.

    Returns:
        dict mapping field name -> stripped substring, in the order listed below.
        Ranges that fall past the end of a short line yield empty strings.
    """
    # (field name, start index, stop index) -- stop is exclusive; None means "to end of line".
    layout = (
        ("station_id", 0, 8),
        ("record_type", 9, 11),
        ("coop_station_id", 12, 18),
        ("climate_division", 19, 21),
        ("WBAN_ID", 22, 27),
        ("WMO_ID", 28, 33),
        ("FAA_ID", 34, 39),
        ("NWS_ID", 40, 44),
        ("ICAO_ID", 45, 49),
        ("country", 50, 70),
        ("state_FIPS", 71, 73),
        ("county", 74, 104),
        ("time_zone", 105, 110),
        ("coop_station_name", 111, 141),
        ("principal_station_name", 142, 172),
        ("begin_date", 173, 181),
        ("end_date", 182, 191),
        ("lat_deg", 192, 194),
        ("lat_min", 195, 197),
        ("lat_sec", 198, 200),
        ("lon_deg", 201, 205),
        ("lon_min", 206, 208),
        ("lon_sec", 209, 211),
        ("latlon_precision", 212, 219),
        ("ground_elevation", 219, 225),
        ("elevation_other", 226, 229),
        ("elevation_other_type", 230, 231),
        ("station_relocation", 232, 243),
        ("station_types", 244, None),
    )
    return {field: row[start:stop].strip() for field, start, stop in layout}


# Parse every line held in the first (only) column of the raw MSHR frame.
parsed_rows = [parse_fixed_width(line) for line in mshr.iloc[:, 0]]

# Largest number of fields produced for any row.
max_cols = max(map(len, parsed_rows))

mshr_parse = pd.DataFrame(parsed_rows)

In [0]:
territory_FIPS = ['AS','GU','MP','PR','UM','VI']  #via https://www.census.gov/library/reference/code-lists/ansi.html#states


def _dms_to_decimal(deg, minutes, seconds):
    """Convert degree / minute / second Series into signed decimal degrees.

    Minutes and seconds extend the magnitude of the degree value, so for a negative
    degree they must move the result further from zero (-86 deg 30 min is -86.5, not
    -85.5). Rows with a missing component stay NaN.
    NOTE(review): a "-0" degree value loses its sign when parsed as a number.
    """
    sign = np.where(deg < 0, -1.0, 1.0)
    return sign * (deg.abs() + minutes / 60 + seconds / 3600)


# End dates whose year is 9999 are rewritten to year 2025 so they parse as real timestamps.
mshr_parse['end_date'] = mshr_parse['end_date'].apply(
    lambda x: pd.to_datetime('2025' + x[4:]) if x.startswith('9999') else pd.to_datetime(x)
)

mshr_parse['begin_date']=pd.to_datetime(mshr_parse['begin_date'], errors='coerce')

# Only the first three characters of the precision field are kept.
mshr_parse['latlon_precision']=mshr_parse['latlon_precision'].apply(lambda x: x[:3].strip(' '))

mshr_parse['lat_deg']=pd.to_numeric(mshr_parse['lat_deg'])
mshr_parse['lat_min']=pd.to_numeric(mshr_parse['lat_min'],errors='raise')
mshr_parse['lat_sec']=pd.to_numeric(mshr_parse['lat_sec'],errors='raise')

mshr_parse['lon_deg']=pd.to_numeric(mshr_parse['lon_deg'],errors='raise')
mshr_parse['lon_min']=pd.to_numeric(mshr_parse['lon_min'],errors='raise')
mshr_parse['lon_sec']=pd.to_numeric(mshr_parse['lon_sec'],errors='raise')

# Sign-aware conversion: identical to deg + min/60 + sec/3600 for non-negative degrees,
# corrected for negative ones.
mshr_parse['lat_dd'] = _dms_to_decimal(
    mshr_parse['lat_deg'], mshr_parse['lat_min'], mshr_parse['lat_sec'])

mshr_parse['lon_dd'] = _dms_to_decimal(
    mshr_parse['lon_deg'], mshr_parse['lon_min'], mshr_parse['lon_sec'])


# Keep records whose country is UNITED STATES or whose state code is one of the territories.
mshr_us= mshr_parse[(mshr_parse['state_FIPS'].isin(territory_FIPS)) | (mshr_parse['country'] == 'UNITED STATES')]

mshr_us = mshr_us[mshr_us['end_date'] >= pd.to_datetime('20150101')] #only stations that were functioning at the time of flights dataset

mshr_df = spark.createDataFrame(mshr_us)

#rename MSHR columns to match YDF where relevant
mshr_df=mshr_df.withColumnsRenamed({'lat_dd':'LATITUDE', 
                                    'lon_dd':'LONGITUDE', 
                                    'principal_station_name':'NAME',
                                    'country':'COUNTRY'})





In [0]:
mshr_df.columns

# Location Nulls

In [0]:
# WBAN = last five characters of STATION; COUNTRY = last two characters of NAME.
ydf = ydf.withColumn("WBAN", F.substring(F.col('STATION'), -5, 5))
ydf_country = ydf.withColumn("COUNTRY", F.substring(F.col('NAME'), -2, 2))

In [0]:

#make ICAO lookup table for when we can't match the WBAN to MSHR df to find the geoloc
# One row is kept per STATION. The window orders by its own partition key, so which row
# is kept is arbitrary. ICAO_ID is the first whitespace-delimited token that follows
# "METAR" or "SPECI" in REM.
_one_per_station = Window.partitionBy('STATION').orderBy(F.col("STATION").desc())
_mentions_metar_or_speci = F.col('REM').contains('METAR') | F.col('REM').contains('SPECI')

ICAO_lookup = (
    ydf_country
    .filter(F.col('REM').isNotNull())
    .filter(_mentions_metar_or_speci)
    .withColumn("row_num", F.row_number().over(_one_per_station))
    .filter(F.col('row_num') == 1)
    .drop('row_num')
    .withColumn("ICAO_ID", F.regexp_extract(F.col("REM"), r"(?:METAR|SPECI)\s(\S+)", 1))
)


display(ICAO_lookup)

In [0]:
display(ICAO_lookup)

In [0]:
# Attach the per-station ICAO_ID; stations with no METAR/SPECI remark get a null.
_station_icao = ICAO_lookup.select('STATION', 'ICAO_ID')
ydf_country = ydf_country.join(_station_icao, on='STATION', how='left_outer')
display(ydf_country)

In [0]:
# Rows with no LATITUDE take their location from MSHR instead. The weather-side LATITUDE,
# LONGITUDE and NAME columns are removed here so the MSHR columns of the same names fill
# in after the join below.
missing_loc = ydf_country \
        .filter(F.col("LATITUDE").isNull()) \
        .select(F.expr("* EXCEPT(LATITUDE, LONGITUDE, NAME)")) #will replace those cols w MSHR fill ins


# A weather row matches an MSHR record when the WBAN or the ICAO id agrees and the
# observation DATE is not after the record's end_date.
# NOTE(review): there is no lower bound against begin_date; the window below only prefers
# the candidate with the latest begin_date -- confirm that is the intended rule.
join_condition = (
    ((missing_loc["WBAN"] == mshr_df["WBAN_ID"]) |
     (missing_loc["ICAO_ID"] == mshr_df["ICAO_ID"])
    ) &
    (missing_loc["DATE"] <= mshr_df["end_date"])
)

# Ranks the joined candidates per (WBAN, DATE); rank 1 is the only row kept further down.
window_spec = Window \
                    .partitionBy(missing_loc["WBAN"], missing_loc["DATE"]) \
                    .orderBy(F.col("begin_date").desc() #most recent begin date
)
                    
#deduplicated missing loc rows: based on EDA showing all dups in these null rows are SOD with nulls

# (WBAN, DATE) keys that occur more than once, joined back to their rows, minus the rows
# whose REPORT_TYPE contains "SOD".
dedup = (missing_loc.groupBy("WBAN", "DATE").count().filter(F.col("count") > 1)) \
            .join(missing_loc, on=["WBAN", "DATE"], how="inner") \
            .filter(~F.col("REPORT_TYPE").contains("SOD"))

# Every row whose (WBAN, DATE) key is not present in dedup.
nondup = missing_loc.join(dedup, on=["WBAN", "DATE"], how="left_anti")

# union() is positional: both sides start with the join keys WBAN, DATE once the count
# column is removed from dedup.
missing_loc_clean = dedup.select(F.expr("* EXCEPT(COUNT)")).union(nondup)


# Left join keeps rows with no MSHR match (their MSHR columns stay null).
missing_loc_result = missing_loc_clean \
    .join(mshr_df, join_condition, "left_outer") \
    .withColumn("row_num", F.row_number().over(window_spec)) \
    .filter(F.col("row_num") == 1) \
    .drop("row_num")


display(missing_loc_result)

In [0]:
missing_loc_result.count()

In [0]:
# Rows whose NAME is still null after the MSHR join, i.e. no MSHR record supplied a name.
no_mshr = missing_loc_result.filter(F.col('NAME').isNull())
# Only a cell's last expression is echoed, so the count has to be printed to be seen.
print(no_mshr.count())
no_mshr.select('WBAN').distinct().show()

In [0]:
# Columns whose nulls are tallied per row (see sum_expression) to choose between
# duplicate records.
# NOTE(review): despite the name, some entries look coded/categorical rather than numeric
# (e.g. HourlyPresentWeatherType, HourlySkyConditions, DailyWeather) -- confirm the list.
numeric_cols = ['HourlyAltimeterSetting',
 'HourlyDewPointTemperature',
 'HourlyDryBulbTemperature',
 'HourlyPrecipitation',
 'HourlyPresentWeatherType',
 'HourlyPressureChange',
 'HourlyPressureTendency',
 'HourlyRelativeHumidity',
 'HourlySkyConditions',
 'HourlySeaLevelPressure',
 'HourlyStationPressure',
 'HourlyVisibility',
 'HourlyWetBulbTemperature',
 'HourlyWindDirection',
 'HourlyWindGustSpeed',
 'HourlyWindSpeed',
 'DailyAverageDewPointTemperature',
 'DailyAverageDryBulbTemperature',
 'DailyAverageRelativeHumidity',
 'DailyAverageSeaLevelPressure',
 'DailyAverageStationPressure',
 'DailyAverageWetBulbTemperature',
 'DailyAverageWindSpeed',
 'DailyCoolingDegreeDays',
 'DailyDepartureFromNormalAverageTemperature',
 'DailyHeatingDegreeDays',
 'DailyMaximumDryBulbTemperature',
 'DailyMinimumDryBulbTemperature',
 'DailyPeakWindDirection',
 'DailyPeakWindSpeed',
 'DailyPrecipitation',
 'DailySnowDepth',
 'DailySnowfall',
 'DailySustainedWindDirection',
 'DailySustainedWindSpeed',
 'DailyWeather',
 'MonthlyAverageRH',
 'MonthlyDaysWithGT001Precip',
 'MonthlyDaysWithGT010Precip',
 'MonthlyDaysWithGT32Temp',
 'MonthlyDaysWithGT90Temp',
 'MonthlyDaysWithLT0Temp',
 'MonthlyDaysWithLT32Temp',
 'MonthlyDepartureFromNormalAverageTemperature',
 'MonthlyDepartureFromNormalCoolingDegreeDays',
 'MonthlyDepartureFromNormalHeatingDegreeDays',
 'MonthlyDepartureFromNormalMaximumTemperature',
 'MonthlyDepartureFromNormalMinimumTemperature',
 'MonthlyDepartureFromNormalPrecipitation',
 'MonthlyDewpointTemperature',
 'MonthlyGreatestPrecip',
 'MonthlyGreatestPrecipDate',
 'MonthlyGreatestSnowDepth',
 'MonthlyGreatestSnowDepthDate',
 'MonthlyGreatestSnowfall',
 'MonthlyGreatestSnowfallDate',
 'MonthlyMaxSeaLevelPressureValue',
 'MonthlyMaxSeaLevelPressureValueDate',
 'MonthlyMaxSeaLevelPressureValueTime',
 'MonthlyMaximumTemperature',
 'MonthlyMeanTemperature',
 'MonthlyMinSeaLevelPressureValue',
 'MonthlyMinSeaLevelPressureValueDate',
 'MonthlyMinSeaLevelPressureValueTime',
 'MonthlyMinimumTemperature',
 'MonthlySeaLevelPressure',
 'MonthlyStationPressure',
 'MonthlyTotalLiquidPrecipitation',
 'MonthlyTotalSnowfall',
 'MonthlyWetBulb',
  'ShortDurationEndDate005',
 'ShortDurationEndDate010',
 'ShortDurationEndDate015',
 'ShortDurationEndDate020',
 'ShortDurationEndDate030',
 'ShortDurationEndDate045',
 'ShortDurationEndDate060',
 'ShortDurationEndDate080',
 'ShortDurationEndDate100',
 'ShortDurationEndDate120',
 'ShortDurationEndDate150',
 'ShortDurationEndDate180',
 'ShortDurationPrecipitationValue005',
 'ShortDurationPrecipitationValue010',
 'ShortDurationPrecipitationValue015',
 'ShortDurationPrecipitationValue020',
 'ShortDurationPrecipitationValue030',
 'ShortDurationPrecipitationValue045',
 'ShortDurationPrecipitationValue060',
 'ShortDurationPrecipitationValue080',
 'ShortDurationPrecipitationValue100',
 'ShortDurationPrecipitationValue120',
 'ShortDurationPrecipitationValue150',
 'ShortDurationPrecipitationValue180']

In [0]:
# NOTE(review): loc_result is first assigned in a later cell of this notebook, so on a
# fresh kernel (Restart & Run All) this cell raises NameError -- it only works when run
# out of order.
overlap=missing_loc_result.join(loc_result,on='STATION', how='inner')
display(overlap) #double checking there's no information we can copy over

In [0]:
from functools import reduce
from operator import add



# Per-row count of null values across numeric_cols (1 per null column, summed).
sum_expression = reduce(add, [F.when(F.col(c).isNull(), 1).otherwise(0) for c in numeric_cols])

# US rows that already carry a LATITUDE, tagged with their null count.
_loc = (
    ydf_country
    .filter(F.col("LATITUDE").isNotNull())
    .filter(F.col('COUNTRY') == 'US')
    .withColumn("null_count", sum_expression)
)

# Keep one row per (STATION, DATE): the one with the fewest nulls. Ties are broken
# arbitrarily. The former trailing .drop('loc_result') named a column that never exists
# (it is the DataFrame's variable name), so it did nothing and is removed.
_fewest_nulls_first = Window.partitionBy("STATION", "DATE").orderBy(F.col("null_count").asc())

loc_result = (
    _loc
    .withColumn("row_num", F.row_number().over(_fewest_nulls_first))
    .filter(F.col("row_num") == 1)
    .drop('row_num')
)

display(loc_result)

In [0]:
loc_result = loc_result.drop(F.col('null_count'))

In [0]:
loc_result.count()

In [0]:
loc_result.count()

In [0]:
display(missing_loc_result)

In [0]:
loc_result.columns

In [0]:
# The join left some column names appearing twice. For each such name the FIRST
# occurrence is suffixed with '_duplicated' (list.index returns the first position) so it
# can be dropped afterwards; the second occurrence keeps the plain name.
df_cols = missing_loc_result.columns

# positions of the first occurrence of every name that appears exactly twice
duplicate_col_index = list({df_cols.index(name) for name in df_cols if df_cols.count(name) == 2})

for position in duplicate_col_index:
    df_cols[position] += '_duplicated'

# names flagged for removal
cols_to_remove = [name for name in df_cols if '_duplicated' in name]

missing_loc_result = missing_loc_result.toDF(*df_cols)



In [0]:
display(missing_loc_result)

In [0]:
missing_loc_result=missing_loc_result.drop(*cols_to_remove)

In [0]:
display(loc_result)

In [0]:
# union() matches columns by position, so the MSHR-filled rows are first projected into
# loc_result's exact column order.
_mshr_filled_aligned = missing_loc_result.select(*loc_result.columns)
full = _mshr_filled_aligned.union(loc_result)

In [0]:
full.filter(F.col('LATITUDE').isNull()).count()

In [0]:
display(full)

In [0]:
checkpoint_path = "dbfs:/student-groups/Group_4_1/interim/weather_1y_checkpoint"

# Spark's checkpoint files go in their own directory. Pointing setCheckpointDir at the
# same path that is overwritten below would mix checkpoint files into the parquet output
# and lets the overwrite delete the files the write is reading from.
spark.sparkContext.setCheckpointDir(f"{checkpoint_path}_rdd_checkpoints")

# Eager checkpoint materialises `full` and cuts its lineage before the write.
weather_ydf_checkpointed = full.checkpoint(eager=True)
weather_ydf_checkpointed.write.mode('overwrite').parquet(checkpoint_path)


# Interpolation 

## Time grid

In [0]:
checkpoint_path = "dbfs:/student-groups/Group_4_1/interim/weather_1y_checkpoint"

ydf_ = spark.read.parquet(checkpoint_path)
display(ydf_)

In [0]:

#prepare for time lookup, UTC conversion and time grid interpolation
# Coordinates become doubles; DATE has its 'T' replaced by a space and is parsed into
# the DATETIME timestamp column.
ydf_ = (
    ydf_
    .withColumn('LATITUDE', F.col('LATITUDE').cast(types.DoubleType()))
    .withColumn('LONGITUDE', F.col('LONGITUDE').cast(types.DoubleType()))
    .withColumn('DATETIME', F.regexp_replace(F.col('DATE'), 'T', ' '))
    .withColumn('DATETIME', F.to_timestamp('DATETIME', 'yyyy-MM-dd HH:mm:ss'))
)

display(ydf_)

In [0]:

#get unique stations so we can use their lat/lon for a lookup table
# One row per STATION; the window orders by its own partition key, so the surviving row
# is arbitrary.
_one_row_per_station = Window.partitionBy('STATION').orderBy(F.col("STATION").desc())

ydf_uniq = (
    ydf_
    .withColumn("row_num", F.row_number().over(_one_row_per_station))
    .filter(F.col('row_num') == 1)
    .drop('row_num')
)

In [0]:
#code was used to create a lookup table, but this was saved to parquet and should not be run again on the 1y dataset

def find_timezone(lat, lng):
    """Return the timezone name for a coordinate, or "Unknown" when none can be found.

    Args:
        lat: Latitude in decimal degrees, or None.
        lng: Longitude in decimal degrees, or None.

    Returns:
        The string reported by TimezoneFinder.timezone_at, or "Unknown" when either
        coordinate is missing, is rejected as invalid, or no timezone is found.
    """
    # Null coordinates reach the UDF as None; answer with the sentinel instead of
    # letting the lookup raise and fail the whole Spark job.
    if lat is None or lng is None:
        return "Unknown"
    tf = TimezoneFinder()
    try:
        timezone_str = tf.timezone_at(lat=lat, lng=lng)
    except ValueError:
        # Coordinates rejected by the library (e.g. out of range) are treated as unknown.
        return "Unknown"
    return timezone_str if timezone_str else "Unknown"

# Wrap the Python lookup as a Spark UDF that returns a string.
find_timezone_udf = udf(find_timezone, StringType())


# One timezone per unique station row, written out as a lookup table.
ydf_tz = ydf_uniq.withColumn("timezone", find_timezone_udf(col("LATITUDE"), col("LONGITUDE")))
folder_path = "dbfs:/student-groups/Group_4_1"
# NOTE(review): no write mode is given here, so Spark's default applies and this line
# errors if the target already exists; the next cell writes the same path with
# mode('overwrite').
ydf_tz.write.parquet(f"{folder_path}/external/weather_tz_lookup.parquet")

In [0]:
ydf_tz.write.mode('overwrite').parquet(f"{folder_path}/external/weather_tz_lookup.parquet")

In [0]:
display(ydf_tz.limit(10))

In [0]:

ydf_time = ydf_.join(ydf_tz.select('STATION','timezone'), ['STATION'], 'left_outer')

In [0]:


def get_utc(datetime_str, timezone_str):
    '''Using timezone information, get localized time of the datetime col then convert to UTC

    Args:
        datetime_str: Local wall-clock time formatted as '%Y-%m-%d %H:%M:%S', or None.
        timezone_str: Timezone name understood by pytz, or None / "" / "Unknown".

    Returns:
        A timezone-aware datetime in UTC, or None when the timestamp is missing or the
        timezone is missing or not recognised.
    '''
    # Missing inputs and the "Unknown" sentinel produced by find_timezone cannot be
    # converted; return a null instead of raising inside the UDF.
    if datetime_str is None or not timezone_str or timezone_str == "Unknown":
        return None
    try:
        t = pytz.timezone(timezone_str)
    except pytz.UnknownTimeZoneError:
        return None
    dt = datetime.strptime(datetime_str, '%Y-%m-%d %H:%M:%S')
    local_dt = t.localize(dt)
    utc_dt = local_dt.astimezone(pytz.utc)
    return utc_dt

# Spark UDF wrapper around get_utc; the result is stored as a timestamp.
get_utc_udf = udf(get_utc, TimestampType())

# The UDF parses text, so DATETIME is passed to it as a string.
_datetime_text = F.col('DATETIME').cast('string')

ydf_time = (
    ydf_time
    .withColumn('DATETIME', F.to_timestamp(_datetime_text))
    .withColumn('utc_datetime', get_utc_udf(_datetime_text, F.col('timezone')))
)

display(ydf_time)

In [0]:
checkpoint_path = "dbfs:/student-groups/Group_4_1/interim/weather_1y_checkpoint"

# Spark's checkpoint files go in their own directory. Pointing setCheckpointDir at the
# same path that is overwritten below would mix checkpoint files into the parquet output
# and lets the overwrite delete the files the write is reading from.
spark.sparkContext.setCheckpointDir(f"{checkpoint_path}_rdd_checkpoints")

# Eager checkpoint materialises ydf_time and cuts its lineage before the write.
weather_ydf_checkpointed = ydf_time.checkpoint(eager=True)
weather_ydf_checkpointed.write.mode('overwrite').parquet(checkpoint_path)


## strategy 1: METAR

https://www.weather.gov/media/wrh/mesowest/metar_decode_key.pdf

In [0]:
checkpoint_path = "dbfs:/student-groups/Group_4_1/interim/weather_1y_checkpoint"

ydf_time = spark.read.parquet(f"{checkpoint_path}")


### Precipitation

**3- AND 6-HOUR PRECIPITATION AMOUNT:** 
- 6RRRR; precipitation amount in .01 inches 
  - for past 6 hours reported in 00, 06, 12, and 18 UTC observations
  - for past 3 hours in 03, 09, 15, and 21 UTC observations;
- a trace is 60000. 

**HOURLY PRECIPITATION AMOUNT:** 
- Prrrr; in .01 inches since last METAR; 
- a trace is P0000. 

**WEATHER PHENOMENA**
- RA: liquid precipitation that does not freeze;
- SN: frozen precipitation other than hail;
- UP: precipitation of unknown type;
  - intensity prefixed to precipitation: light (-), moderate (no sign), heavy (+);
- FG: fog;
- FZFG: freezing fog (temperature below 0°C);
- BR: mist; 
- HZ: haze; 
- SQ: squall

maximum of three groups reported; augmented by observer: 
- FC (funnel cloud/tornado/waterspout);
- TS (thunderstorm);
- GR (hail); 
- GS (small hail; < 1/4 inch); 
- FZRA (intensity; freezing rain); 
- VA (volcanic ash). 


**24-HOUR PRECIPITATION AMOUNT:**

- 7R24R24R24R24; precipitation amount in .01 inches for past 24 hours reported in 12 UTC observation, e.g., 70015. 

In [0]:
ydf_time.groupBy(F.col('HourlyPrecipitation')).count().show()

In [0]:
ydf_time.filter(F.col('HourlyPrecipitation').isNull()).count() #pre-REM interpolation

In [0]:
#extract information from REM column

# Steps applied to HourlyPrecipitation, in order:
#   1. where it is null or '*', take the digits that follow " P" in REM and scale by 0.01
#      (no match gives an empty string, which casts to null);
#   2. replace every 'T' in the value with '0.01';
#   3. keep only the first decimal number found in the value (drops any other characters);
#   4. cast to double.
# NOTE(review): the notebook's notes say P0000 denotes a trace, which step 1 turns into
# 0.0 while a 'T' becomes 0.01 in step 2 -- confirm the two should differ.
ydf_metar = ydf_time \
    .withColumn(
        'HourlyPrecipitation',
        F.when(
            (F.col("HourlyPrecipitation").isNull()) | (F.col("HourlyPrecipitation") == '*'),
            (F.regexp_extract(F.col("REM"), r" P(\d+)", 1).cast("int") * 0.01) # hundredths of inch kept in "remarks" section
        ).otherwise(F.col("HourlyPrecipitation"))
    ) \
    .withColumn('HourlyPrecipitation', F.regexp_replace('HourlyPrecipitation', 'T', '0.01')) \
    .withColumn(
        'HourlyPrecipitation',
        F.regexp_extract('HourlyPrecipitation', r"[0-9]+(\.[0-9]+)?", 0) # Match digits
    ) \
    .withColumn('HourlyPrecipitation', F.col('HourlyPrecipitation').cast(DoubleType()))
    


In [0]:
ydf_metar.filter(F.col('HourlyPrecipitation').isNull()).count()
display(ydf_metar.filter(F.col('HourlyPrecipitation').isNull()))

In [0]:
# Counts rows whose precipitation is still null, whose REM has no " P" substring, and
# whose present-weather type is also null.
ydf_metar.filter(F.col('HourlyPrecipitation').isNull()) \
    .filter(~F.col('REM').contains(' P')) \
    .filter(F.col('HourlyPresentWeatherType').isNull()).count()

# These are the precipitation nulls that neither REM nor the categorical weather type
# can fill, so they will have to be interpolated.
# Open question: weight the categorical weather type by the continuous value?


In [0]:
ydf_metar.filter(F.col('HourlyPrecipitation').isNull()).count() #post-REM interpolation

In [0]:
ydf_metar = ydf_metar.withColumn('HourlyPrecipitation',
                     ydf_metar.HourlyPrecipitation.cast(DoubleType()))

In [0]:
# Observations of one station in UTC time order.
window_spec = Window.partitionBy('STATION').orderBy("utc_datetime")

_precip = F.col("HourlyPrecipitation")

# A null reading is filled only when the rows immediately before AND after it (same
# station) both have a value; the fill is their plain average.
_fillable_gap = (
    _precip.isNull()
    & F.col("prev_value").isNotNull()
    & F.col("next_value").isNotNull()
)
_neighbour_mean = (F.col("prev_value") + F.col("next_value")) / 2

ydf_interpolate = (
    ydf_metar
    .withColumn("prev_value", F.lag("HourlyPrecipitation").over(window_spec))
    .withColumn("next_value", F.lead("HourlyPrecipitation").over(window_spec))
    .withColumn("prev_time", F.lag("utc_datetime").over(window_spec))
    .withColumn("next_time", F.lead("utc_datetime").over(window_spec))
    .withColumn("HourlyPrecipitation", F.when(_fillable_gap, _neighbour_mean).otherwise(_precip))
)


In [0]:
ydf_interpolate.filter(F.col('HourlyPrecipitation').isNull()).count()
#count that will need to be geo-interpolated

In [0]:
ydf_interpolate=ydf_interpolate.drop('prev_value', 'next_value', 'prev_time', 'next_time')

In [0]:
# Attach station_id / neighbor_id from the stations table to every weather row.
# NOTE(review): if stations holds more than one row per station_id, this left join
# repeats each weather row once per matching stations row -- confirm before relying on
# row counts downstream.
ydf_interpolate = ydf_interpolate.join(stations.select('station_id','neighbor_id'),
                     on=(ydf_interpolate.STATION == stations.station_id),
                     how='left_outer'
                     )


In [0]:
display(stations.head(10))

In [0]:
missing = ydf_interpolate.filter(F.col("HourlyPrecipitation").isNull()).select('STATION','utc_datetime','HourlyPrecipitation','neighbor_id')

In [0]:
display(ydf_interpolate.limit(10))

In [0]:
display(ydf_interpolate.limit(10))

In [0]:
with_neighbor = missing.alias("m").join(
    stations.alias("s"),
    F.col("m.STATION") == F.col("s.station_id"),
    "left"
).select(
    "m.STATION", 
    "m.utc_datetime",
    "m.neighbor_id",
    F.col("s.neighbor_id").alias("target_neighbor")
)

In [0]:
# complete = ydf_interpolate.filter(F.col("HourlyPrecipitation").isNotNull()).select('STATION','utc_datetime','HourlyPrecipitation','neighbor_id')


In [0]:

# Rows that already have a precipitation value: the pool neighbours are drawn from and
# the rows unioned back at the end. The only other definition of `complete` in the
# notebook is commented out, so it is defined here to keep this cell runnable.
complete = (
    ydf_interpolate
    .filter(F.col("HourlyPrecipitation").isNotNull())
    .select('STATION', 'utc_datetime', 'HourlyPrecipitation', 'neighbor_id')
)

# For each missing row, pair it with every complete observation of its neighbour station
# and keep the one closest in time (absolute difference in epoch seconds).
imputed = (
    missing.alias("orig")
    .join(
        complete.alias("nn"),
        F.col("orig.neighbor_id") == F.col("nn.STATION"),
        "left"
    )
    .withColumn("time_diff", F.abs(F.col("orig.utc_datetime").cast("long") 
                                 - F.col("nn.utc_datetime").cast("long")))
    .withColumn("rank", F.row_number().over(
        Window.partitionBy("orig.STATION", "orig.utc_datetime")
              .orderBy(F.asc("time_diff"))
    ))
    .filter(F.col("rank") == 1)
    .select(  # Explicit column selection
        [F.col(f"orig.{c}") for c in missing.columns if c != "HourlyPrecipitation"] +
        [F.coalesce("orig.HourlyPrecipitation", "nn.HourlyPrecipitation").alias("HourlyPrecipitation")]
    )
)

# Combine with original complete data
final_df = imputed.unionByName(complete.select(imputed.columns))


In [0]:
# imputed = (ydf_interpolate
#     .alias("orig")
#     .join(ydf_interpolate.alias("nn"),
#           F.col("orig.neighbor_id") == F.col("nn.STATION"),
#           "left")
#     .withColumn("time_diff", F.abs(
#         F.col("orig.utc_datetime").cast("long") - 
#         F.col("nn.utc_datetime").cast("long")))
#     .withColumn("rank", F.row_number().over(
#         Window.partitionBy("orig.STATION", "orig.utc_datetime")
#               .orderBy("time_diff")))
#     .filter(F.col("rank") == 1)
#     .withColumn("HourlyPrecipitation",
#                 F.coalesce("orig.HourlyPrecipitation", 
#                            "nn.HourlyPrecipitation"))
# )


# display(imputed)


In [0]:
stations.count()

In [0]:
# 2. Get neighbor relationships from stations table
with_neighbor = missing.alias("m").join(
    stations.alias("s"),
    F.col("m.STATION") == F.col("s.station_id"),
    "inner"
).select(
    "m.STATION", 
    "m.utc_datetime",
    "m.neighbor_id",  # Explicitly select needed columns
    F.col("s.neighbor_id").alias("target_neighbor")
)

# 3. Find closest temporal match
# Candidates are neighbour observations within 86400 seconds (24 hours) of the missing
# row. time_diff uses the same epoch-second arithmetic as the join condition; subtracting
# the raw timestamps instead yields an interval rather than a number.
# NOTE(review): neighbour rows are taken from ydf_interpolate unfiltered, so the closest
# match may itself have a null HourlyPrecipitation -- confirm whether nulls should be
# excluded from the candidates.
imputed = (
    with_neighbor.alias("w")
    .join(
        ydf_interpolate.alias("n"),
        (F.col("w.target_neighbor") == F.col("n.STATION")) & 
        (F.abs(F.col("w.utc_datetime").cast("long") - 
               F.col("n.utc_datetime").cast("long")) <= 86400),
        "inner"
    )
    .withColumn("time_diff", F.abs(F.col("w.utc_datetime").cast("long")
                                 - F.col("n.utc_datetime").cast("long")))
    .withColumn("rnk", F.row_number().over(
        Window.partitionBy("w.STATION", "w.utc_datetime")
              .orderBy(F.asc("time_diff"))
    ))
    .filter(F.col("rnk") == 1)
    .select(
        "w.STATION",
        "w.utc_datetime",
        "w.neighbor_id",
        F.col("n.HourlyPrecipitation").alias("HourlyPrecipitation")  # Directly use neighbor's value
    )
)

# # 4. Schema-safe combination
# final_columns = ["STATION", "utc_datetime", "neighbor_id", "HourlyPrecipitation"]

# final_df = ydf_interpolate.filter(F.col("HourlyPrecipitation").isNotNull()) \
#                          .select(final_columns) \
#                          .unionByName(imputed.select(final_columns))


In [0]:
display(final_df.limit(10))

### Wind

Direction in tens of degrees from true north (first three digits); next two digits: speed in whole knots; as needed Gusts (character) followed by maximum observed speed; always appended with KT to indicate knots; 00000KT for calm; if direction varies by 60° or more a Variable wind direction group is reported. 


PEAK WIND: PK WND dddff(f)/(hh)mm; direction in tens of degrees, speed in whole knots, and time. 

WIND SHIFT: WSHFT (hh)mm 

In [0]:
display(ydf_metar.filter(F.col('HourlyWindSpeed').isNull()).filter(F.col('REM').contains('WND')))

In [0]:
checkpoint_path = "dbfs:/student-groups/Group_4_1/interim/weather_1y_interpolation_base"

# Spark's checkpoint files go in their own directory. Pointing setCheckpointDir at the
# same path that is overwritten below would mix checkpoint files into the parquet output
# and lets the overwrite delete the files the write is reading from.
spark.sparkContext.setCheckpointDir(f"{checkpoint_path}_rdd_checkpoints")

# Eager checkpoint materialises ydf_interpolate and cuts its lineage before the write.
weather_ydf_checkpointed = ydf_interpolate.checkpoint(eager=True)
weather_ydf_checkpointed.write.mode('overwrite').parquet(checkpoint_path)


# QDF Prep

In [0]:
# WBAN = last five characters of STATION; COUNTRY = last two characters of NAME.
qdf = qdf.withColumn("WBAN", F.substring(F.col('STATION'), -5, 5))
qdf_country = qdf.withColumn("COUNTRY", F.substring(F.col('NAME'), -2, 2))

In [0]:

#make ICAO lookup table for when we can't match the WBAN to MSHR df to find the geoloc
# One row is kept per STATION. The window orders by its own partition key, so which row
# is kept is arbitrary. ICAO_ID is the first whitespace-delimited token that follows
# "METAR" or "SPECI" in REM.
_q_one_per_station = Window.partitionBy('STATION').orderBy(F.col("STATION").desc())
_q_mentions_metar_or_speci = F.col('REM').contains('METAR') | F.col('REM').contains('SPECI')

ICAO_qlookup = (
    qdf_country
    .filter(F.col('REM').isNotNull())
    .filter(_q_mentions_metar_or_speci)
    .withColumn("row_num", F.row_number().over(_q_one_per_station))
    .filter(F.col('row_num') == 1)
    .drop('row_num')
    .withColumn("ICAO_ID", F.regexp_extract(F.col("REM"), r"(?:METAR|SPECI)\s(\S+)", 1))
)


display(ICAO_qlookup)

In [0]:
# Attach the per-station ICAO_ID; stations with no METAR/SPECI remark get a null.
_q_station_icao = ICAO_qlookup.select('STATION', 'ICAO_ID')
qdf_country = qdf_country.join(_q_station_icao, on='STATION', how='left_outer')
display(qdf_country)

In [0]:
# Same MSHR location fill as for the 1-year data, applied to the 3-month frame. Rows
# with no LATITUDE lose their weather-side LATITUDE, LONGITUDE and NAME columns so the
# MSHR columns of the same names fill in after the join below.
missing_loc = qdf_country \
        .filter(F.col("LATITUDE").isNull()) \
        .select(F.expr("* EXCEPT(LATITUDE, LONGITUDE, NAME)")) #will replace those cols w MSHR fill ins


# A weather row matches an MSHR record when the WBAN or the ICAO id agrees and the
# observation DATE is not after the record's end_date.
# NOTE(review): there is no lower bound against begin_date; the window below only prefers
# the candidate with the latest begin_date -- confirm that is the intended rule.
join_condition = (
    ((missing_loc["WBAN"] == mshr_df["WBAN_ID"]) |
     (missing_loc["ICAO_ID"] == mshr_df["ICAO_ID"])
    ) &
    (missing_loc["DATE"] <= mshr_df["end_date"])
)

# Ranks the joined candidates per (WBAN, DATE); rank 1 is the only row kept further down.
window_spec = Window \
                    .partitionBy(missing_loc["WBAN"], missing_loc["DATE"]) \
                    .orderBy(F.col("begin_date").desc() #most recent begin date
)
                    
#deduplicated missing loc rows: based on EDA showing all dups in these null rows are SOD with nulls

# (WBAN, DATE) keys that occur more than once, joined back to their rows, minus the rows
# whose REPORT_TYPE contains "SOD".
dedup = (missing_loc.groupBy("WBAN", "DATE").count().filter(F.col("count") > 1)) \
            .join(missing_loc, on=["WBAN", "DATE"], how="inner") \
            .filter(~F.col("REPORT_TYPE").contains("SOD"))

# Every row whose (WBAN, DATE) key is not present in dedup.
nondup = missing_loc.join(dedup, on=["WBAN", "DATE"], how="left_anti")

# union() is positional: both sides start with the join keys WBAN, DATE once the count
# column is removed from dedup.
missing_loc_clean = dedup.select(F.expr("* EXCEPT(COUNT)")).union(nondup)


# Left join keeps rows with no MSHR match (their MSHR columns stay null).
missing_loc_result = missing_loc_clean \
    .join(mshr_df, join_condition, "left_outer") \
    .withColumn("row_num", F.row_number().over(window_spec)) \
    .filter(F.col("row_num") == 1) \
    .drop("row_num")


display(missing_loc_result)

In [0]:
missing_loc_result.count()

In [0]:
# Rows whose NAME is still null after the MSHR join, i.e. no MSHR record supplied a name.
no_mshr = missing_loc_result.filter(F.col('NAME').isNull())
# Only a cell's last expression is echoed, so the count has to be printed to be seen.
print(no_mshr.count())
no_mshr.select('WBAN').distinct().show()

In [0]:
display(missing_loc_result.filter(F.col('WBAN')=='37201'))

In [0]:
from functools import reduce
from operator import add


In [0]:
# Columns whose nulls are tallied per row (see sum_expression) to choose between
# duplicate records. This repeats the list defined earlier in the notebook.
# NOTE(review): despite the name, some entries look coded/categorical rather than numeric
# (e.g. HourlyPresentWeatherType, HourlySkyConditions, DailyWeather) -- confirm the list.
numeric_cols = ['HourlyAltimeterSetting',
 'HourlyDewPointTemperature',
 'HourlyDryBulbTemperature',
 'HourlyPrecipitation',
 'HourlyPresentWeatherType',
 'HourlyPressureChange',
 'HourlyPressureTendency',
 'HourlyRelativeHumidity',
 'HourlySkyConditions',
 'HourlySeaLevelPressure',
 'HourlyStationPressure',
 'HourlyVisibility',
 'HourlyWetBulbTemperature',
 'HourlyWindDirection',
 'HourlyWindGustSpeed',
 'HourlyWindSpeed',
 'DailyAverageDewPointTemperature',
 'DailyAverageDryBulbTemperature',
 'DailyAverageRelativeHumidity',
 'DailyAverageSeaLevelPressure',
 'DailyAverageStationPressure',
 'DailyAverageWetBulbTemperature',
 'DailyAverageWindSpeed',
 'DailyCoolingDegreeDays',
 'DailyDepartureFromNormalAverageTemperature',
 'DailyHeatingDegreeDays',
 'DailyMaximumDryBulbTemperature',
 'DailyMinimumDryBulbTemperature',
 'DailyPeakWindDirection',
 'DailyPeakWindSpeed',
 'DailyPrecipitation',
 'DailySnowDepth',
 'DailySnowfall',
 'DailySustainedWindDirection',
 'DailySustainedWindSpeed',
 'DailyWeather',
 'MonthlyAverageRH',
 'MonthlyDaysWithGT001Precip',
 'MonthlyDaysWithGT010Precip',
 'MonthlyDaysWithGT32Temp',
 'MonthlyDaysWithGT90Temp',
 'MonthlyDaysWithLT0Temp',
 'MonthlyDaysWithLT32Temp',
 'MonthlyDepartureFromNormalAverageTemperature',
 'MonthlyDepartureFromNormalCoolingDegreeDays',
 'MonthlyDepartureFromNormalHeatingDegreeDays',
 'MonthlyDepartureFromNormalMaximumTemperature',
 'MonthlyDepartureFromNormalMinimumTemperature',
 'MonthlyDepartureFromNormalPrecipitation',
 'MonthlyDewpointTemperature',
 'MonthlyGreatestPrecip',
 'MonthlyGreatestPrecipDate',
 'MonthlyGreatestSnowDepth',
 'MonthlyGreatestSnowDepthDate',
 'MonthlyGreatestSnowfall',
 'MonthlyGreatestSnowfallDate',
 'MonthlyMaxSeaLevelPressureValue',
 'MonthlyMaxSeaLevelPressureValueDate',
 'MonthlyMaxSeaLevelPressureValueTime',
 'MonthlyMaximumTemperature',
 'MonthlyMeanTemperature',
 'MonthlyMinSeaLevelPressureValue',
 'MonthlyMinSeaLevelPressureValueDate',
 'MonthlyMinSeaLevelPressureValueTime',
 'MonthlyMinimumTemperature',
 'MonthlySeaLevelPressure',
 'MonthlyStationPressure',
 'MonthlyTotalLiquidPrecipitation',
 'MonthlyTotalSnowfall',
 'MonthlyWetBulb',
  'ShortDurationEndDate005',
 'ShortDurationEndDate010',
 'ShortDurationEndDate015',
 'ShortDurationEndDate020',
 'ShortDurationEndDate030',
 'ShortDurationEndDate045',
 'ShortDurationEndDate060',
 'ShortDurationEndDate080',
 'ShortDurationEndDate100',
 'ShortDurationEndDate120',
 'ShortDurationEndDate150',
 'ShortDurationEndDate180',
 'ShortDurationPrecipitationValue005',
 'ShortDurationPrecipitationValue010',
 'ShortDurationPrecipitationValue015',
 'ShortDurationPrecipitationValue020',
 'ShortDurationPrecipitationValue030',
 'ShortDurationPrecipitationValue045',
 'ShortDurationPrecipitationValue060',
 'ShortDurationPrecipitationValue080',
 'ShortDurationPrecipitationValue100',
 'ShortDurationPrecipitationValue120',
 'ShortDurationPrecipitationValue150',
 'ShortDurationPrecipitationValue180']

In [0]:



# Per-row count of null values across numeric_cols (1 per null column, summed).
sum_expression = reduce(add, [F.when(F.col(c).isNull(), 1).otherwise(0) for c in numeric_cols])

# US rows that already carry a LATITUDE, tagged with their null count.
_loc = (
    qdf_country
    .filter(F.col("LATITUDE").isNotNull())
    .filter(F.col('COUNTRY') == 'US')
    .withColumn("null_count", sum_expression)
)

# Keep one row per (STATION, DATE): the one with the fewest nulls. Ties are broken
# arbitrarily. The former trailing .drop('loc_result') named a column that never exists
# (it is the DataFrame's variable name), so it did nothing and is removed.
_q_fewest_nulls_first = Window.partitionBy("STATION", "DATE").orderBy(F.col("null_count").asc())

loc_result = (
    _loc
    .withColumn("row_num", F.row_number().over(_q_fewest_nulls_first))
    .filter(F.col("row_num") == 1)
    .drop('row_num')
)

display(loc_result)

In [0]:
loc_result = loc_result.drop(F.col('null_count'))

In [0]:
loc_result.count()

In [0]:
# The join left some column names appearing twice. For each such name the FIRST
# occurrence is suffixed with '_duplicated' (list.index returns the first position) and
# then dropped; the second occurrence keeps the plain name.
df_cols = missing_loc_result.columns

# positions of the first occurrence of every name that appears exactly twice
duplicate_col_index = list({df_cols.index(name) for name in df_cols if df_cols.count(name) == 2})

for position in duplicate_col_index:
    df_cols[position] += '_duplicated'

# names flagged for removal
cols_to_remove = [name for name in df_cols if '_duplicated' in name]

missing_loc_result = missing_loc_result.toDF(*df_cols)

missing_loc_result = missing_loc_result.drop(*cols_to_remove)

In [0]:
# union() matches columns by position, so the MSHR-filled rows are first projected into
# loc_result's exact column order.
_q_mshr_filled_aligned = missing_loc_result.select(*loc_result.columns)
full = _q_mshr_filled_aligned.union(loc_result)

In [0]:
full.filter(F.col('LATITUDE').isNull()).count()

In [0]:
# NOTE(review): qdf_uniq is first assigned in a later cell, so this raises NameError on a
# fresh kernel (Restart & Run All).
display(qdf_uniq)

In [0]:
# Row count per STATION among rows whose LATITUDE is null.
# NOTE(review): qdf_ is first assigned in the next cell, so this raises NameError on a
# fresh kernel (Restart & Run All).
qdf_.filter(F.col('LATITUDE').isNull()).groupBy('STATION').count().show()

In [0]:

#prepare for time lookup, UTC conversion and time grid interpolation
# Coordinates become doubles; DATE has its 'T' replaced by a space and is parsed into
# the DATETIME timestamp column.
qdf_ = (
    full
    .withColumn('LATITUDE', F.col('LATITUDE').cast(types.DoubleType()))
    .withColumn('LONGITUDE', F.col('LONGITUDE').cast(types.DoubleType()))
    .withColumn('DATETIME', F.regexp_replace(F.col('DATE'), 'T', ' '))
    .withColumn('DATETIME', F.to_timestamp('DATETIME', 'yyyy-MM-dd HH:mm:ss'))
)



In [0]:
qdf_=qdf_.filter(F.col('STATION') != '99999937201')

In [0]:
#get unique stations so we can use their lat/lon for a lookup table
# One row per STATION; the window orders by its own partition key, so the surviving row
# is arbitrary.
_q_one_row_per_station = Window.partitionBy('STATION').orderBy(F.col("STATION").desc())

qdf_uniq = (
    qdf_
    .withColumn("row_num", F.row_number().over(_q_one_row_per_station))
    .filter(F.col('row_num') == 1)
    .drop('row_num')
)


In [0]:


def find_timezone(lat, lng):
    """Return the timezone name for a coordinate, or "Unknown" when none can be found.

    Args:
        lat: Latitude in decimal degrees, or None.
        lng: Longitude in decimal degrees, or None.

    Returns:
        The string reported by TimezoneFinder.timezone_at, or "Unknown" when either
        coordinate is missing, is rejected as invalid, or no timezone is found.
    """
    # Null coordinates reach the UDF as None; answer with the sentinel instead of
    # letting the lookup raise and fail the whole Spark job.
    if lat is None or lng is None:
        return "Unknown"
    tf = TimezoneFinder()
    try:
        timezone_str = tf.timezone_at(lat=lat, lng=lng)
    except ValueError:
        # Coordinates rejected by the library (e.g. out of range) are treated as unknown.
        return "Unknown"
    return timezone_str if timezone_str else "Unknown"

find_timezone_udf = udf(find_timezone, StringType())


In [0]:

# Timezone per unique station, then copied onto every observation of that station.
_q_station_timezone = find_timezone_udf(col("LATITUDE"), col("LONGITUDE"))
qdf_tz = qdf_uniq.withColumn("timezone", _q_station_timezone)

qdf_time = qdf_.join(
    qdf_tz.select('STATION', 'timezone'),
    ['STATION'],
    'left_outer',
)


In [0]:
display(qdf_tz.limit(10))

In [0]:
qdf_.cache()

In [0]:
display(qdf_)

In [0]:
display(qdf_tz.limit(10))

In [0]:
qdf_tz.cache()

In [0]:
folder_path = "dbfs:/student-groups/Group_4_1"

qdf_tz.write.mode('overwrite').parquet(f"{folder_path}/external/3m_weather_tz_lookup.parquet")

In [0]:


def get_utc(datetime_str, timezone_str):
    '''Using timezone information, get localized time of the datetime col then convert to UTC

    Args:
        datetime_str: Local wall-clock time formatted as '%Y-%m-%d %H:%M:%S', or None.
        timezone_str: Timezone name understood by pytz, or None / "" / "Unknown".

    Returns:
        A timezone-aware datetime in UTC, or None when the timestamp is missing or the
        timezone is missing or not recognised.
    '''
    # Missing inputs and the "Unknown" sentinel produced by find_timezone cannot be
    # converted; return a null instead of raising inside the UDF.
    if datetime_str is None or not timezone_str or timezone_str == "Unknown":
        return None
    try:
        t = pytz.timezone(timezone_str)
    except pytz.UnknownTimeZoneError:
        return None
    dt = datetime.strptime(datetime_str, '%Y-%m-%d %H:%M:%S')
    local_dt = t.localize(dt)
    utc_dt = local_dt.astimezone(pytz.utc)
    return utc_dt

get_utc_udf = udf(get_utc, TimestampType())



In [0]:
# The UDF parses text, so DATETIME is passed to it as a string.
_q_datetime_text = F.col('DATETIME').cast('string')

qdf_time = (
    qdf_time
    .withColumn('DATETIME', F.to_timestamp(_q_datetime_text))
    .withColumn('utc_datetime', get_utc_udf(_q_datetime_text, F.col('timezone')))
)



In [0]:
display(qdf_time)

In [0]:
checkpoint_path = "dbfs:/student-groups/Group_4_1/interim/weather_3m_checkpoint"

# Spark's checkpoint files go in their own directory. Pointing setCheckpointDir at the
# same path that is overwritten below would mix checkpoint files into the parquet output
# and lets the overwrite delete the files the write is reading from.
spark.sparkContext.setCheckpointDir(f"{checkpoint_path}_rdd_checkpoints")

# Eager checkpoint materialises qdf_time and cuts its lineage before the write.
weather_qdf_checkpointed = qdf_time.checkpoint(eager=True)
weather_qdf_checkpointed.write.mode('overwrite').parquet(checkpoint_path)


# Sanity Check Sandbox

In [0]:

display(ydf_interpolate.filter(F.col('STATION')=='70219599999'))
display(ydf_time.filter(F.col('STATION')=='70219599999'))

In [0]:
display(ydf_metar.filter(F.col('HourlyPrecipitation').isNull() &
                 (F.col('REM').contains(' P000')))
                .select('REM')
                )

In [0]:
#try to find a rule for selecting duplicates

#duplicate NAME-DATE pairs
# Restricted to FM-15 reports from US stations; a pair counts as a duplicate when more
# than one such row shares it.
duplicates = (
    ydf_country
    .filter(F.col('REPORT_TYPE') == 'FM-15')
    .filter(F.col('COUNTRY') == 'US')
    .groupBy('NAME', 'DATE')
    .count()
    .filter(F.col("count") > 1)
    .select("NAME", "DATE")
)

#sources for each duplicate NAME-DATE as an array
# NOTE(review): the join goes back to the unfiltered ydf_country, so SOURCE values from
# rows that are not FM-15 are collected as well -- confirm that is intended.
source_combinations = (
    ydf_country
    .join(duplicates, on=["NAME", "DATE"], how="inner")
    .groupBy("NAME", "DATE")
    .agg(F.collect_set("SOURCE").alias("source_list"))
)

#get unique (SOURCE1, SOURCE2) pairs

def generate_pairs(source_list):
    """Return every unordered pair of entries in source_list.

    Each pair is a tuple with its two members sorted; pairs follow the order produced by
    itertools.combinations. Lists with fewer than two entries give an empty list.
    """
    if len(source_list) <= 1:
        return []
    pairs = []
    for first, second in combinations(source_list, 2):
        pairs.append(tuple(sorted((first, second))))
    return pairs

# UDF form of generate_pairs; each pair comes back as a struct (source1, source2).
generate_pairs_udf = F.udf(generate_pairs, "array<struct<source1:string, source2:string>>")

source_pairs = source_combinations.withColumn("source_pairs", generate_pairs_udf(F.col("source_list")))

#explode the pairs & count occurrences per NAME
# One output row per (NAME, pair) with the number of NAME-DATE groups it appeared in.
pair_counts = (
    source_pairs
    .select("NAME", F.explode("source_pairs").alias("pair"))
    .groupBy("NAME", "pair")
    .count()
)

#collapse to show only name & pair tally
# One row per NAME holding a list of (pair, count) structs.
final_result = (
    pair_counts
    .groupBy("NAME")
    .agg(F.collect_list(F.struct("pair", "count")).alias("source_tally"))
)

display(final_result)
