# Weather Cleanup

## Imports and setup

In [0]:
!pip install timezonefinder

In [0]:
from datetime import datetime

import pyspark.sql.functions as F
import pytz
from pyspark.sql import Window
from timezonefinder import TimezoneFinder

## Variables and Directories

In [0]:
# Variables and directories
# Root of the course datasets.
data_BASE_DIR = "dbfs:/mnt/mids-w261/datasets_final_project_2022"
# Team workspace: holds the Spark checkpoints plus the external/ and interim/ outputs.
team_BASE_DIR = "dbfs:/student-groups/Group_4_1"
spark.sparkContext.setCheckpointDir(f"{team_BASE_DIR}/checkpoints")
period = "" # one of the following values ("", "_3m", "_6m", "_1y")

# Datasets
# The weather path is derived from data_BASE_DIR so the dataset root is defined in one place.
weather = spark.read.parquet(f"{data_BASE_DIR}/parquet_weather_data{period}/")
# STATION -> timezone lookup table; it is extended and re-saved later in this notebook.
stations_tz = spark.read.parquet(f"{team_BASE_DIR}/external/station_tz_lookup.parquet")

In [0]:
# Sanity check, one value per line: weather rows, weather columns, lookup rows.
print(weather.count(), len(weather.columns), stations_tz.count(), sep="\n")

In [0]:
# List the contents of the shared dataset root, then of the team workspace.
display(dbutils.fs.ls(data_BASE_DIR))
display(dbutils.fs.ls(team_BASE_DIR))

In [0]:
# Weather Dataset quick look: rendered table first, then row count and column count.
display(weather)
print(f"{weather.count()}\n{len(weather.columns)}")

## Step 1: Clean Weather Data

### 1. FILTER ONLY USA LOCATIONS

In [0]:
# Derive WBAN (last 5 characters of STATION) and COUNTRY (last 2 characters of NAME),
# then keep only the rows whose COUNTRY is 'US'.
weather_us = (
    weather
    .withColumn("WBAN", F.col("STATION").substr(-5, 5))
    .withColumn("COUNTRY", F.col("NAME").substr(-2, 2))
    .where(F.col("COUNTRY") == "US")
)

# cache() marks the frame for caching and returns the same frame; checkpoint() then
# materializes it under the checkpoint directory and returns a frame with truncated lineage.
weather_us = weather_us.cache().checkpoint()

In [0]:
# Show the row count of the checkpointed US-only weather frame.
weather_us.count()

In [0]:
# Columns whose nulls are counted when choosing which duplicate row to keep.
features = [
    "HourlyDewPointTemperature",
    "HourlyDryBulbTemperature",
    "HourlyPrecipitation",
    "HourlyPresentWeatherType",
    "HourlyPressureChange",
    "HourlyPressureTendency",
    "HourlyRelativeHumidity",
    "HourlySkyConditions",
    "HourlySeaLevelPressure",
    "HourlyStationPressure",
    "HourlyVisibility",
    "HourlyWetBulbTemperature",
    "HourlyWindDirection",
    "HourlyWindGustSpeed",
    "HourlyWindSpeed",
    "REM",
]

### 2. Checking for Duplicates

In [0]:
# Duplicate rows based on Station and date.
# Counts every row that belongs to a (STATION, DATE) group with more than one row.
duplicate_count = (
    weather_us
    .groupBy("STATION", "DATE")
    .agg(F.count("*").alias("count"))
    .filter(F.col("count") > 1)
    .agg(F.sum("count"))
    .first()[0]
) or 0  # the sum over zero groups is NULL (None in Python); report it as 0

print(f"Number of duplicate rows: {duplicate_count}")

### 3. REMOVE DUPLICATES
Deduplicate on STATION and DATE: for each (STATION, DATE) pair, keep the row with the fewest nulls across the feature columns and drop the others.

In [0]:
from functools import reduce
from operator import add

# Per-row count of null feature values: one 0/1 flag per feature column, added together.
null_flags = [
    F.when(F.col(feature).isNull(), 1).otherwise(0)
    for feature in features
]
sum_expression = reduce(add, null_flags)

# Rank the rows of each (STATION, DATE) group by ascending null count,
# so the row ranked 1 is one with the fewest missing feature values.
dedup_window = Window.partitionBy("STATION", "DATE").orderBy(F.col("null_count").asc())

weather_dedup = (
    weather_us
    .withColumn("null_count", sum_expression)
    .withColumn("row_num", F.row_number().over(dedup_window))
    .filter(F.col("row_num") == 1)
    # Helper columns are only needed to pick the surviving row.
    .drop("row_num")
    .drop("null_count")
)

# Cache, then checkpoint to materialize the result and truncate its lineage.
weather_dedup = weather_dedup.cache().checkpoint()

### 4. VALIDATING DUPLICATES ARE REMOVED

In [0]:
# Duplicate rows based on Station and date, re-checked on the deduplicated frame.
# Counts every row that belongs to a (STATION, DATE) group with more than one row.
duplicate_count = (
    weather_dedup
    .groupBy("STATION", "DATE")
    .agg(F.count("*").alias("count"))
    .filter(F.col("count") > 1)
    .agg(F.sum("count"))
    .first()[0]
) or 0  # no duplicate groups -> the sum is NULL (None); report 0, the expected result here

print(f"Number of duplicate rows: {duplicate_count}")

## Step 2: Handling Time

### 1. Validate Missing Station ID in Stations data before the join

In [0]:
# Extract unique Stations id with longitude and latitude from weather
station_id = weather_dedup.select("STATION", "LATITUDE", "LONGITUDE").distinct()

# Extract unique Stations id from Station time zone data
tz_station_id = stations_tz.select("STATION").distinct()

# Find missing stations in the time zones table.
# left_anti keeps only the weather stations that have no match in the lookup.
# The coordinates are cast with the "double" type name, which needs no extra type import.
missing_station = (
    station_id
    .join(tz_station_id, "STATION", "left_anti")
    .withColumn("LATITUDE", F.col("LATITUDE").cast("double"))
    .withColumn("LONGITUDE", F.col("LONGITUDE").cast("double"))
    .select("STATION", "LATITUDE", "LONGITUDE").distinct()
).cache()

# Collect the station ids to the driver (one entry per row of missing_station).
missing_station_lst = [row["STATION"] for row in missing_station.select("STATION").collect()]
print(f"Number of missing stations: {len(missing_station_lst)}")

### 2. Finding missing timezones using coordinates

In [0]:
# ============================
# UDF: Timezones Lookup
# ============================
def find_timezone(lat, lng):
    """Return the timezone name for a coordinate pair, or "Unknown".

    Args:
        lat: Latitude in degrees, or None.
        lng: Longitude in degrees, or None.

    Returns:
        The timezone name reported by TimezoneFinder, or the string "Unknown"
        when a coordinate is missing, is rejected by TimezoneFinder as out of
        bounds, or no timezone is found.
    """
    # Coordinates are cast to double upstream, so a value that failed the cast arrives as None.
    if lat is None or lng is None:
        return "Unknown"
    # A finder is created per call; the UDF is only applied to the distinct station rows.
    tf = TimezoneFinder()
    try:
        timezone_str = tf.timezone_at(lat=lat, lng=lng)
    except ValueError:
        # timezonefinder raises ValueError for out-of-bounds coordinates.
        return "Unknown"
    return timezone_str if timezone_str else "Unknown"

# Define the UDF for time zone lookup; "string" is the return type.
find_timezone_udf = F.udf(find_timezone, "string")

if len(missing_station_lst) > 0:
    missing_stations_tz = (
        missing_station
        .withColumn("timezone", find_timezone_udf(F.col("LATITUDE"), F.col("LONGITUDE")))
        )

    # Augmenting Stations Timezones data with the missing stations data.
    # unionByName matches columns by name, so a different column order in the
    # stored lookup cannot silently misalign values.
    stations_tz = stations_tz.unionByName(missing_stations_tz).cache()
    # Checkpoint before writing: stations_tz was read from the very path that is
    # overwritten below, so its lineage must no longer depend on those files.
    stations_tz = stations_tz.checkpoint()
    # re-save timezones data as a parquet file
    stations_tz.write.mode("overwrite").parquet(f"{team_BASE_DIR}/external/station_tz_lookup.parquet")

### 3. Finding Time zone using stations time zones helper table

In [0]:
# Register as temporary views for SQL use
weather_dedup.createOrReplaceTempView("weather_dedup")
stations_tz.createOrReplaceTempView("timezones")

# Keep exactly one lookup row per STATION. The lookup is built from distinct
# (STATION, LATITUDE, LONGITUDE) triples, so a station may appear more than once;
# joining on a repeated key would multiply the weather rows that were just deduplicated.
tz_lookup = stations_tz.select("STATION", "timezone").dropDuplicates(["STATION"])

# Apply Broadcast Join for small timezones table
tz_broadcast = F.broadcast(tz_lookup)

# Left join keeps every weather row; stations without a lookup entry get a null timezone.
weather_tz = (
    weather_dedup
    .join(tz_broadcast.alias("a1"), weather_dedup.STATION == F.col("a1.STATION"), "left")
    .select(
        weather_dedup["*"],
        F.col("a1.timezone").alias("STATION_timezone"),
    )
)

# Cache the data to avoid recomputing the time zones
weather_tz.cache()
weather_tz = weather_tz.checkpoint()

### 4. Checking if we have any missing stations after the join

In [0]:
# Validation statistics: after the join, no row should have a null STATION_timezone.
# Both counts come from a single aggregation; F.count on a column counts its non-null values.
match_counts = weather_tz.agg(
    F.count("*").alias("total"),
    F.count("STATION_timezone").alias("matched"),
).first()

total_weather_tz = match_counts["total"]
station_tz_match_count = match_counts["matched"]

# Guard against an empty frame, where the rate would otherwise divide by zero.
match_rate = station_tz_match_count / total_weather_tz if total_weather_tz else 0.0

print(f"Stations timeszones match rate: {match_rate:.2%} - {total_weather_tz - station_tz_match_count} timeszones unmatched")

### 5. Converting to UTC

In [0]:
# ============================
# UDF: Convert local observation time to UTC
# ============================

def to_utc(dt: str, tz: str) -> str:
    """Convert a local timestamp string to a UTC timestamp string.

    Args:
        dt: Local timestamp formatted as "%Y-%m-%dT%H:%M:%S", or None.
        tz: Timezone name understood by pytz, or None/empty.

    Returns:
        The UTC timestamp in the same format, or None when `dt` is None, `tz`
        is missing, or `tz` is not a timezone name pytz recognizes (for
        example the "Unknown" placeholder stored by the timezone lookup).

    Raises:
        ValueError: if `dt` does not match the expected format.
    """
    if dt is None:
        return None
    dt_format = "%Y-%m-%dT%H:%M:%S"
    local_dt = datetime.strptime(dt, dt_format)
    if not tz:
        return None
    try:
        timezone = pytz.timezone(tz)
    except pytz.UnknownTimeZoneError:
        # An unrecognized name would otherwise raise inside the UDF and fail the whole job.
        return None
    # localize() attaches the zone; with its default is_dst=False an ambiguous
    # local time (DST fall-back hour) is interpreted as standard time.
    local_dt = timezone.localize(local_dt)

    # Convert to UTC
    utc_dt = local_dt.astimezone(pytz.utc)
    return utc_dt.strftime(dt_format)

# Wrap to_utc as a Spark UDF; with no explicit return type F.udf returns strings.
utc_udf = F.udf(to_utc)

# Add the UTC timestamp column. The timezone is cast with the "string" type name
# rather than F.StringType(), which is not part of the public functions API.
weather_utc = (
    weather_tz
    .withColumn("weather_datetime_utc", utc_udf(F.col("DATE"), F.col("STATION_timezone").cast("string")))
)

# Cache the data to avoid recomputing the UTC
weather_utc.cache()
weather_utc = weather_utc.checkpoint()

In [0]:
# Preview the weather rows together with the new weather_datetime_utc column.
display(weather_utc)

In [0]:
# Show the row count of the UTC-augmented weather frame.
weather_utc.count()

In [0]:
# Persist the cleaned weather data (including the UTC column) to the team's interim
# folder, replacing any previous output for the same period.
(
    weather_utc.write
    .mode("overwrite")
    .parquet(f"{team_BASE_DIR}/interim/weather{period}_checkpoint")
)

## Step 0: Create the station time zone helper table for the first time
_(don't rerun unless necessary)_

In [0]:
# Extract unique Stations id with longitude and latitude from weather.
# Requires weather_dedup and find_timezone_udf from the cells above.
# The coordinates are cast with the "double" type name, which needs no extra type import.
station_tz = (
    weather_dedup
    .select("STATION", "LATITUDE", "LONGITUDE").distinct()
    .withColumn("LATITUDE", F.col("LATITUDE").cast("double"))
    .withColumn("LONGITUDE", F.col("LONGITUDE").cast("double"))
    )

# add time zone column
station_tz = station_tz.withColumn("timezone", find_timezone_udf(F.col("LATITUDE"), F.col("LONGITUDE")))

# re-save timezones data as a parquet file
# WARNING: this overwrites the lookup that the setup cell loads into stations_tz.
station_tz.write.mode("overwrite").parquet(f"{team_BASE_DIR}/external/station_tz_lookup.parquet")