In [1]:
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression
from pyspark.sql.window import Window

In [2]:
#### Data Preparation Stage

In [3]:
### Helper function design
def load_data_as_sqlview(table_name: str, file_pattern: str, spark: SparkSession):
    # Use the specified file_pattern to read the files
    df = spark.read.option("header", "true").option("inferSchema", "true").csv(file_pattern)
    # Create a temporary SQL view for the DataFrame
    df.createOrReplaceTempView(table_name)
    return df

def run_and_show(spark, query, lines=3):
    result = spark.sql(query)
    result.show(lines)
    return result

In [4]:
# Create spark session
spark = SparkSession.builder.appName("WeatherDataProcessing").getOrCreate()

In [5]:
# Start loading data

In [6]:
station_reading = load_data_as_sqlview(
    table_name="station_reading",
    file_pattern = "/home/jovyan/data/2019/part-*.csv.gz",
    spark=spark)

In [7]:
station_list = load_data_as_sqlview(
    table_name="station_list",
    file_pattern = "/home/jovyan/stationlist.csv",
    spark=spark)

In [8]:
country_list = load_data_as_sqlview(
    table_name="country_list",
    file_pattern = "/home/jovyan/countrylist.csv",
    spark=spark)

In [9]:
# Check data
station_reading.show(3)

+------+-----+--------+----+----+------+------+-----+----+-----+----+-----+-----+-----+-----+------+
|STN---| WBAN|YEARMODA|TEMP|DEWP|   SLP|   STP|VISIB|WDSP|MXSPD|GUST|  MAX|  MIN| PRCP| SNDP|FRSHTT|
+------+-----+--------+----+----+------+------+-----+----+-----+----+-----+-----+-----+-----+------+
| 10260|99999|20190101|26.1|21.2|1001.9| 987.5| 20.6| 9.0| 15.9|29.7| 29.8|21.7*|0.02G| 18.5|  1000|
| 10260|99999|20190102|24.9|22.1|1020.1|1005.5|  5.4| 5.6| 13.6|22.1|27.1*| 20.7|0.48G| 22.8|  1000|
| 10260|99999|20190103|31.7|29.1|1008.9| 994.7| 13.6|11.6| 21.4|49.5|37.4*|26.8*|0.25G|999.9| 11000|
+------+-----+--------+----+----+------+------+-----+----+-----+----+-----+-----+-----+-----+------+
only showing top 3 rows



In [10]:
# rename column name
station_reading = station_reading.withColumnRenamed("STN---", "STN_NO_RD")
station_reading.createOrReplaceTempView('station_reading')

In [11]:
# Check data
station_list.show(3)

+------+------------+
|STN_NO|COUNTRY_ABBR|
+------+------------+
|012240|          NO|
|020690|          SW|
|020870|          SW|
+------+------------+
only showing top 3 rows



In [12]:
# Check data
country_list.show(3)

+------------+-------------------+
|COUNTRY_ABBR|       COUNTRY_FULL|
+------------+-------------------+
|          AA|              ARUBA|
|          AC|ANTIGUA AND BARBUDA|
|          AF|        AFGHANISTAN|
+------------+-------------------+
only showing top 3 rows



In [14]:
# Join the DataFrames using DataFrame operations
station_reading_full = station_reading\
    .join(station_list,  station_list["STN_NO"].cast("int") == station_reading["STN_NO_RD"].cast("int"), how="left")\
    .join(country_list, station_list["COUNTRY_ABBR"] == country_list["COUNTRY_ABBR"], how="left")
station_reading_full.createOrReplaceTempView('station_reading_full')

In [15]:
# Check data
query = "SELECT * FROM station_reading_full LIMIT 3"
run_and_show(spark=spark, query=query)

+---------+-----+--------+----+----+------+------+-----+----+-----+----+-----+-----+-----+-----+------+------+------------+------------+------------+
|STN_NO_RD| WBAN|YEARMODA|TEMP|DEWP|   SLP|   STP|VISIB|WDSP|MXSPD|GUST|  MAX|  MIN| PRCP| SNDP|FRSHTT|STN_NO|COUNTRY_ABBR|COUNTRY_ABBR|COUNTRY_FULL|
+---------+-----+--------+----+----+------+------+-----+----+-----+----+-----+-----+-----+-----+------+------+------------+------------+------------+
|    10260|99999|20190101|26.1|21.2|1001.9| 987.5| 20.6| 9.0| 15.9|29.7| 29.8|21.7*|0.02G| 18.5|  1000|010260|          NO|          NO|      NORWAY|
|    10260|99999|20190102|24.9|22.1|1020.1|1005.5|  5.4| 5.6| 13.6|22.1|27.1*| 20.7|0.48G| 22.8|  1000|010260|          NO|          NO|      NORWAY|
|    10260|99999|20190103|31.7|29.1|1008.9| 994.7| 13.6|11.6| 21.4|49.5|37.4*|26.8*|0.25G|999.9| 11000|010260|          NO|          NO|      NORWAY|
+---------+-----+--------+----+----+------+------+-----+----+-----+----+-----+-----+-----+-----+----

DataFrame[STN_NO_RD: int, WBAN: int, YEARMODA: int, TEMP: double, DEWP: double, SLP: double, STP: double, VISIB: double, WDSP: double, MXSPD: double, GUST: double, MAX: string, MIN: string, PRCP: string, SNDP: double, FRSHTT: int, STN_NO: string, COUNTRY_ABBR: string, COUNTRY_ABBR: string, COUNTRY_FULL: string]

In [16]:
# Check join
query = """
SELECT 
    COUNT(DISTINCT STN_NO) AS slo,
    COUNT(DISTINCT STN_NO_RD) AS sr
FROM station_reading_full;
"""
run_and_show(spark=spark, query=query)

+-----+-----+
|  slo|   sr|
+-----+-----+
|12144|12144|
+-----+-----+



DataFrame[slo: bigint, sr: bigint]

In [17]:
# Check join
query = """
SELECT DISTINCT STN_NO, STN_NO_RD
FROM station_reading_full
WHERE STN_NO_RD IS NOT NULL 
    AND STN_NO IS NOT NULL 
    AND CAST(STN_NO AS STRING) <> CAST(STN_NO_RD AS STRING)
LIMIT 10;
"""
run_and_show(spark=spark, query=query)

+------+---------+
|STN_NO|STN_NO_RD|
+------+---------+
|043390|    43390|
|013800|    13800|
|020950|    20950|
+------+---------+
only showing top 3 rows



DataFrame[STN_NO: string, STN_NO_RD: int]

In [18]:
# Check join
query = """
SELECT DISTINCT STN_NO, STN_NO_RD
FROM station_reading_full
WHERE STN_NO_RD IS NOT NULL 
    AND STN_NO IS NULL
LIMIT 10;
"""
run_and_show(spark=spark, query=query)

+------+---------+
|STN_NO|STN_NO_RD|
+------+---------+
+------+---------+



DataFrame[STN_NO: string, STN_NO_RD: int]

In [19]:
#### Answering Question Stage

In [20]:
### 1. Which country had the hottest average mean temperature over the year?

In [21]:
# Check date, to make sure we only have 1 year of data
query = """
SELECT 
    MAX(YEARMODA) AS max_date,
    MIN(YEARMODA) AS min_date
FROM station_reading_full
LIMIT 3;
"""
run_and_show(spark=spark, query=query)

+--------+--------+
|max_date|min_date|
+--------+--------+
|20200101|20190101|
+--------+--------+



DataFrame[max_date: int, min_date: int]

In [22]:
query = """
WITH mt AS (
    SELECT COUNTRY_FULL, YEARMODA,  AVG(TEMP) AS mean_temp
    FROM station_reading_full
    WHERE 
        YEARMODA BETWEEN 20190101 AND 20191231
        AND TEMP <> 9999.9
    GROUP BY COUNTRY_FULL, YEARMODA
    )
    
SELECT 
    COUNTRY_FULL, 
    AVG(mean_temp) AS avg_mean_temp
FROM mt
GROUP BY COUNTRY_FULL
ORDER BY AVG(mean_temp) DESC
LIMIT 1;
"""
run_and_show(spark=spark, query=query)

+------------+-----------------+
|COUNTRY_FULL|    avg_mean_temp|
+------------+-----------------+
|    DJIBOUTI|90.06114457831323|
+------------+-----------------+



DataFrame[COUNTRY_FULL: string, avg_mean_temp: double]

In [23]:
### 2. Which country had the most consecutive days of tornadoes/funnel cloud formations?

In [24]:
query = """
SELECT LEN(FRSHTT),
    COUNT(1) AS cnt
FROM station_reading_full
GROUP BY LEN(FRSHTT)
ORDER BY LEN(FRSHTT);
"""
run_and_show(spark=spark, query=query, lines=10)

+-----------+-------+
|len(FRSHTT)|    cnt|
+-----------+-------+
|          1|2807103|
|          2|  33583|
|          4| 188837|
|          5| 897708|
|          6| 234103|
+-----------+-------+



DataFrame[len(FRSHTT): int, cnt: bigint]

In [25]:
query = """
SELECT 
    COUNTRY_FULL,
    YEARMODA,
    MAX(CASE WHEN LEN(FRSHTT) = 6 AND SUBSTRING(FRSHTT, 6, 1) = '1' THEN 1 ELSE 0 END) AS tornado
FROM station_reading_full
WHERE COUNTRY_FULL IS NOT NULL
GROUP BY COUNTRY_FULL, YEARMODA
ORDER BY COUNTRY_FULL, YEARMODA ASC
;
"""
df_q3 = run_and_show(spark=spark, query=query)

+------------+--------+-------+
|COUNTRY_FULL|YEARMODA|tornado|
+------------+--------+-------+
| AFGHANISTAN|20190101|      0|
| AFGHANISTAN|20190102|      0|
| AFGHANISTAN|20190103|      0|
+------------+--------+-------+
only showing top 3 rows



In [26]:
def find_consec_tornado_per_country(spark_df):
    country_stat = dict()  # country_name -> max consecutive tornado days
    df = spark_df.collect()  # Convert to a local Pandas DataFrame
    current_country = ""
    tor_days = 0
    for row in df:
        country_name = row['COUNTRY_FULL']
        tor = row['tornado']
        if country_name != current_country:
            if country_name not in country_stat:
                country_stat[country_name] = 0
            if current_country != "":
                country_stat[current_country] = tor_days if tor_days > country_stat[current_country] else country_stat[current_country]
            tor_days = 0
            current_country = country_name
        if tor == 1:
            tor_days += 1
        else:
            country_stat[country_name] = tor_days if tor_days > country_stat[country_name] else country_stat[country_name]
            tor_days = 0
    return country_stat

def find_max_country(country_max):
    max_tor = {
        'country': "",
        'tor_day': -1
    }
    for c_name in country_max:
        if country_max[c_name] > max_tor['tor_day']:
            max_tor['country'] = c_name
            max_tor['tor_day'] = country_max[c_name]
        elif country_max[c_name] == max_tor['tor_day']:
            max_tor['country'] += f", {c_name}"
    return max_tor

country_max = find_consec_tornado_per_country(df_q3)
find_max_country(country_max)

{'country': 'AUSTRIA, BAHAMAS THE, CANADA, COLOMBIA, CUBA, GEORGIA, GHANA, ICELAND, IRAN, ITALY, MADAGASCAR, NEPAL, NORWAY, POLAND, ROMANIA, RUSSIA, SPAIN, TURKEY, UNITED STATES',
 'tor_day': 1}

In [27]:
# To validate the answer, use the following SQL query to see if there's any country that has at least two(2) consecutive tornado days
query = """
WITH ct AS (
    SELECT 
        COUNTRY_FULL,
        YEARMODA,
        MAX(CASE WHEN LEN(FRSHTT) = 6 AND SUBSTRING(FRSHTT, 6, 1) = '1' THEN 1 ELSE 0 END) AS tornado
    FROM station_reading_full
    WHERE COUNTRY_FULL IS NOT NULL
    GROUP BY COUNTRY_FULL, YEARMODA
    ORDER BY COUNTRY_FULL, YEARMODA ASC
    ),
    
    cs AS (
    SELECT 
        COUNTRY_FULL,
        YEARMODA,
        SUM(tornado) OVER (PARTITION BY COUNTRY_FULL ORDER BY YEARMODA ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) AS consec
    FROM ct
    )
    
SELECT MAX(consec) AS max_consec
FROM cs;
"""
run_and_show(spark=spark, query=query)

+----------+
|max_consec|
+----------+
|         1|
+----------+



DataFrame[max_consec: bigint]

In [28]:
### 3. Which country had the second highest average mean wind speed over the year?

In [29]:
query = """
WITH mt AS (
    SELECT COUNTRY_FULL, YEARMODA,  AVG(WDSP) AS mean_wdsp
    FROM station_reading_full
    WHERE 
        YEARMODA BETWEEN 20190101 AND 20191231
        AND WDSP <> 999.9
    GROUP BY COUNTRY_FULL, YEARMODA
    ),
    
    rk AS (SELECT 
        COUNTRY_FULL, 
        AVG(mean_wdsp) AS avg_mean_wdsp,
        ROW_NUMBER() OVER (ORDER BY AVG(mean_wdsp) DESC) AS rank
    FROM mt
    GROUP BY COUNTRY_FULL)

SELECT *
FROM rk
WHERE rank = 2;
"""
run_and_show(spark=spark, query=query)

+------------+------------------+----+
|COUNTRY_FULL|     avg_mean_wdsp|rank|
+------------+------------------+----+
|       ARUBA|15.981917808219182|   2|
+------------+------------------+----+



DataFrame[COUNTRY_FULL: string, avg_mean_wdsp: double, rank: int]

In [None]:
### [POTENTIAL IMPROVEMENT] Impute missing data by cross station regression [Takes long time, not applying this at the moment]
spark.conf.set("spark.sql.pivotMaxValues", "50000")

# Pivot the DataFrame to have a column for TEMP value for each station for each day within a country
pivot_df = station_reading_full.groupBy("COUNTRY_FULL", "YEARMODA").pivot("STN_NO").agg(F.first("TEMP"))

# Placeholder for the results
imputed_rows = []

# Get the list of unique countries
countries = [row['COUNTRY_FULL'] for row in pivot_df.select("COUNTRY_FULL").distinct().collect()]

# Loop through each country and impute missing values for each station using regression
for country in countries:
    group_df = pivot_df.filter(pivot_df["COUNTRY_FULL"] == country)
    stations = [row['STN_NO'] for row in station_reading_full.filter(station_reading_full["COUNTRY_FULL"] == country).select("STN_NO").distinct().collect()]
    for station in stations:
        # If there's a missing TEMP value (9999.9) within the group for the station
        if group_df.filter(group_df[station] == 9999.9).count() > 0:
            
            # Features for regression (all stations except the current one)
            features = [s for s in stations if s != station]

            # Prepare the data
            assembler = VectorAssembler(inputCols=features, outputCol="features")
            assembled_df = assembler.transform(group_df).na.drop(subset=["features"])

            # Separate training and testing data
            train_df = assembled_df.filter(assembled_df[station] != 999.9)
            test_df = assembled_df.filter(assembled_df[station] == 999.9)

            # Linear regression model
            lr = LinearRegression(featuresCol="features", labelCol=station)
            model = lr.fit(train_df)

            # Predict the missing TEMP values
            predictions = model.transform(test_df)
            imputed_rows.extend(predictions.select("COUNTRY_FULL", "YEARMODA", station).collect())

# Convert the imputed rows back to the original format
converted_rows = []

for row in imputed_rows:
    # Extracting station number from column name
    station_no = [col for col in row.asDict().keys() if F.col.isdigit()][0]
    country = row["COUNTRY_FULL"]
    date = row["YEARMODA"]
    temp_value = row[station_no]
    converted_rows.append((country, date, station_no, temp_value))

imputed_df = spark.createDataFrame(converted_rows, ["COUNTRY_FULL", "YEARMODA", "STN_NO", "TEMP"])
imputed_df.show()