In [None]:
from pyspark.sql import SparkSession

# Spark session & context
spark = SparkSession.builder \
    .appName("world-energy-stats") \
    .master("spark://spark-master:7077")\
    .config("hive.metastore.uris", "thrift://hive-metastore:9083") \
    .enableHiveSupport() \
    .getOrCreate()

#Local Development
# spark = SparkSession.builder.appName("world-energy-stats").master("local").getOrCreate()

In [None]:
df = (spark.read
  .format("csv")
  .option("header", "true")
  .option("inferSchema", "true")
   .load("hdfs://namenode:9000/energy-data/owid-energy-data.csv"))

#Local Development
# df = spark.read.csv("owid-energy-data.csv", header=True, inferSchema=True)

In [None]:
#DROPPING REGIONS (FOR NOW)
df = df.filter(df['iso_code'].isNotNull())

In [None]:
df = df[df['year'] >=1990]

# Drop 2022 as well.
# df = df[df['year'] >=1990]

#40 years of data
grouped_df = df.groupBy("year").count().orderBy("year")
grouped_df.show(40)

In [None]:
# Dropping irrelevant columns
cols_to_drop = [col for col in df.columns if '_per_gdp' if '_per_capita' in col or '_change_pct' in col or '_change_twh' in col]
df = df.drop(*cols_to_drop)
# per_capita_electricity

# Show the updated DataFrame
df.head(n=1)

In [None]:
from pyspark.sql import Window
from pyspark.sql.functions import last, first

temp_column = [column for column in df.columns if 'year' not in column]
temp_column = [column for column in temp_column if 'country' not in column]
temp_column

import pyspark.sql.functions as F

# Define the windows for forward fill and backward fill
ffill_window = "(partition by country order by year rows between unbounded preceding and current row)"
# bfill_window = "(partition by country order by year rows between current row and unbounded following)"

for col in temp_column:
    df = (df.withColumn(col, F.expr(f"case when isnan({col}) then null else {col} end"))
    .withColumn(col, F.expr(f"coalesce({col}, last({col}, true) over {ffill_window})")))
    # .withColumn(col, F.expr(f"coalesce({col}, first({col}, true) over {bfill_window})")))
    

In [None]:
### LEVEL 1 CATEGORIZATION FOR BACKFILLING AND LOGICAL SEPARATION

# Primary Key Columns
primary_keys = ['country', 'year', 'iso_code']

# 1. General Information
df_general = df[primary_keys + ['population', 'gdp', 'electricity_demand', 'electricity_generation', 'primary_energy_consumption']]

# 2. Biofuel
df_biofuel = df[primary_keys + ['biofuel_consumption', 'biofuel_electricity', 'biofuel_share_elec', 'biofuel_share_energy']]

# 3. Coal
df_coal = df[primary_keys + ['coal_consumption', 'coal_electricity', 'coal_production', 'coal_share_elec', 'coal_share_energy']]

# 4. Gas
df_gas = df[primary_keys + ['gas_consumption', 'gas_electricity', 'gas_production', 'gas_share_elec', 'gas_share_energy']]

# 5. Oil
df_oil = df[primary_keys + ['oil_consumption', 'oil_electricity', 'oil_production', 'oil_share_elec', 'oil_share_energy']]

# 6. Fossil Fuels (Aggregate)
df_fossil = df[primary_keys + ['fossil_electricity', 'fossil_fuel_consumption', 'fossil_share_elec', 'fossil_share_energy', 'carbon_intensity_elec']]

# 7. Greenhouse Gas
df_greenhouse_gas = df[primary_keys + ['greenhouse_gas_emissions']]

# 8. Hydro
df_hydro = df[primary_keys + ['hydro_consumption', 'hydro_electricity', 'hydro_share_elec', 'hydro_share_energy']]

# 9. Nuclear
df_nuclear = df[primary_keys + ['nuclear_consumption', 'nuclear_electricity', 'nuclear_share_elec', 'nuclear_share_energy']]

# 10. Renewables (Aggregate)
df_renewables = df[primary_keys + ['renewables_consumption', 'renewables_electricity', 'renewables_share_elec', 'renewables_share_energy']]

# 11. Solar
df_solar = df[primary_keys + ['solar_consumption', 'solar_electricity', 'solar_share_elec', 'solar_share_energy']]

# 12. Wind
df_wind = df[primary_keys + ['wind_consumption', 'wind_electricity', 'wind_share_elec', 'wind_share_energy']]

# 13. Other Renewables
df_other_renewables = df[primary_keys + ['other_renewable_consumption', 'other_renewable_electricity', 'other_renewable_exc_biofuel_electricity', 'other_renewables_share_elec', 'other_renewables_share_elec_exc_biofuel', 'other_renewables_share_energy']]

# 14. Low Carbon
df_low_carbon = df[primary_keys + ['low_carbon_consumption', 'low_carbon_electricity', 'low_carbon_share_elec', 'low_carbon_share_energy']]

# 15. Electricity Imports
df_electricity_imports = df[primary_keys + ['net_elec_imports', 'net_elec_imports_share_demand']]


In [None]:
from pyspark.sql import functions as F

def filter_df_by_threshold(df, threshold):
    """
    Filter a dataframe based on the threshold of non-null counts in non-primary columns.

    Parameters:
    - df: The input dataframe.
    - threshold: The minimum number of non-null values required across non-primary columns.

    Returns:
    - filtered_df: The filtered dataframe.
    - stats: A dictionary containing statistics about the filtering process.
    """

    # Primary Key Columns
    primary_keys = ['country', 'year', 'iso_code']

    # List of columns to check for null values
    columns_to_check = [col for col in df.columns if col not in primary_keys]

    # Count non-null values across all non-primary columns for each country
    agg_exprs = [F.count(F.when(F.col(c).isNotNull(), 1)).alias(c + '_non_null_count') for c in columns_to_check]
    country_counts = df.groupBy('country').agg(*agg_exprs)

    # Sum the non-null counts across all columns for each country
    total_non_null_counts = sum(F.col(c + '_non_null_count') for c in columns_to_check)
    country_counts = country_counts.withColumn('total_non_null_counts', total_non_null_counts)
    
   # Filter countries based on the threshold
    countries_to_keep_df = country_counts.filter(F.col('total_non_null_counts') > threshold).select('country')

    # Find out the countries that were dropped
    all_countries = df.select('country').distinct()
    dropped_countries_df = all_countries.subtract(countries_to_keep_df)
    dropped_countries = [row['country'] for row in dropped_countries_df.collect()]

    # Join with the original DataFrame to get the filtered data
    filtered_df = df.join(countries_to_keep_df, on='country', how='inner')

    original_row_count = df.count()
    filtered_row_count = filtered_df.count()
    rows_dropped = original_row_count - filtered_row_count

    stats = {
        'Original number of rows': original_row_count,
        'Number of rows after filtering': filtered_row_count,
        'Number of rows dropped': rows_dropped,
        'Dropped countries': dropped_countries
    }
    
    print(stats)
    
    return filtered_df

# # Usage example:
# filtered_df_fossil = filter_df_by_threshold(df_fossil, 5)

def count_nulls_by_country(df):
    """
    Count the number of null values for each country and each column (except 'country').

    Parameters:
    - df: The input dataframe.

    Returns:
    - null_counts_df: A dataframe with the count of null values for each column and country.
    """

    # Generate the aggregation expressions
    agg_exprs = [F.count(F.when(F.col(c).isNull(), c)).alias(c) for c in df.columns if c != 'country']

    # Group by 'country' and aggregate
    null_counts_df = df.groupBy("country").agg(*agg_exprs)

    return null_counts_df

# # Usage example:
# null_counts_fossil = count_nulls_by_country(df_fossil)

# # Show the results
# null_counts_fossil.show(n=300)

from pyspark.sql import functions as F

def filter_rows_by_null_threshold(df):
    """
    Filter rows from a dataframe based on the threshold of null values across non-primary columns.

    Parameters:
    - df: The input dataframe.

    Returns:
    - filtered_df: The filtered dataframe.
    - stats: A dictionary containing statistics about the filtering process.
    """
    
    # Primary Key Columns
    primary_keys = ['country', 'year', 'iso_code']

    # List of columns to check for null values
    columns_to_check = [col for col in df.columns if col not in primary_keys]

    # Set the threshold equal to the number of non-primary key columns
    threshold = len(columns_to_check)

    # Calculate the number of nulls for each row
    null_count = sum(F.when(F.col(c).isNull(), 1).otherwise(0) for c in columns_to_check)

    # Filter rows based on the threshold
    filtered_df = df.filter(null_count < threshold)

    original_row_count = df.count()
    filtered_row_count = filtered_df.count()
    rows_dropped = original_row_count - filtered_row_count

    stats = {
        'Original number of rows': original_row_count,
        'Number of rows after filtering': filtered_row_count,
        'Number of rows dropped': rows_dropped
    }
    
    print(stats)

    return filtered_df

# # Usage example:
# filtered_df_fossil, fossil_stats = filter_rows_by_null_threshold(df_fossil)


In [None]:
# Usage example:
null_counts_fossil = count_nulls_by_country(df_fossil)

# # Show the results
null_counts_fossil.show(n=5)

In [None]:
# Calling the filter function on each dataframe
filtered_df_general = filter_df_by_threshold(df_general, 0)
filtered_df_biofuel = filter_df_by_threshold(df_biofuel, 0)
filtered_df_coal = filter_df_by_threshold(df_coal, 0)
filtered_df_gas = filter_df_by_threshold(df_gas, 0)
filtered_df_oil = filter_df_by_threshold(df_oil, 0)
filtered_df_fossil = filter_df_by_threshold(df_fossil, 0)
filtered_df_greenhouse_gas = filter_df_by_threshold(df_greenhouse_gas, 0)
filtered_df_hydro = filter_df_by_threshold(df_hydro, 0)
filtered_df_nuclear = filter_df_by_threshold(df_nuclear, 0)
filtered_df_renewables = filter_df_by_threshold(df_renewables, 0)
filtered_df_solar = filter_df_by_threshold(df_solar, 0)
filtered_df_wind = filter_df_by_threshold(df_wind, 0)
filtered_df_other_renewables = filter_df_by_threshold(df_other_renewables, 0)
filtered_df_low_carbon = filter_df_by_threshold(df_low_carbon, 0)
filtered_df_electricity_imports = filter_df_by_threshold(df_electricity_imports, 0)

In [None]:
filtered_df_ren = filter_df_by_threshold(df_renewables, 0)

# Assuming df is your DataFrame
null_counts_ren = count_nulls_by_country(filtered_df_ren)

# Show the results
null_counts_ren.show(n=10)

In [None]:
# Define the folder path for saving the CSV files
folder_path = './clean/'

# Define file paths for each dataframe within the "clean" folder
filtered_df_general.toPandas().to_csv(folder_path + 'general.csv', index=False)
filtered_df_biofuel.toPandas().to_csv(folder_path + 'biofuel.csv', index=False)
filtered_df_coal.toPandas().to_csv(folder_path + 'coal.csv', index=False)
filtered_df_gas.toPandas().to_csv(folder_path + 'gas.csv', index=False)
filtered_df_oil.toPandas().to_csv(folder_path + 'oil.csv', index=False)
filtered_df_fossil.toPandas().to_csv(folder_path + 'fossil.csv', index=False)
filtered_df_greenhouse_gas.toPandas().to_csv(folder_path + 'greenhouse_gas.csv', index=False)
filtered_df_hydro.toPandas().to_csv(folder_path + 'hydro.csv', index=False)
filtered_df_nuclear.toPandas().to_csv(folder_path + 'nuclear.csv', index=False)
filtered_df_renewables.toPandas().to_csv(folder_path + 'renewables.csv', index=False)
filtered_df_solar.toPandas().to_csv(folder_path + 'solar.csv', index=False)
filtered_df_wind.toPandas().to_csv(folder_path + 'wind.csv', index=False)
filtered_df_other_renewables.toPandas().to_csv(folder_path + 'other_renewables.csv', index=False)
filtered_df_low_carbon.toPandas().to_csv(folder_path + 'low_carbon.csv', index=False)
filtered_df_electricity_imports.toPandas().to_csv(folder_path + 'electricity_imports.csv', index=False)

In [None]:
# save to hive tables.
filtered_df_general.write.mode("overwrite").saveAsTable("wes.general")
filtered_df_biofuel.write.mode("overwrite").saveAsTable("wes.biofuel")
filtered_df_coal.write.mode("overwrite").saveAsTable("wes.coal")
filtered_df_gas.write.mode("overwrite").saveAsTable("wes.gas")
filtered_df_oil.write.mode("overwrite").saveAsTable("wes.oil")
filtered_df_fossil.write.mode("overwrite").saveAsTable("wes.fossil")
filtered_df_greenhouse_gas.write.mode("overwrite").saveAsTable("wes.greenhouse_gas")
filtered_df_hydro.write.mode("overwrite").saveAsTable("wes.hydro")
filtered_df_nuclear.write.mode("overwrite").saveAsTable("wes.nuclear")
filtered_df_renewables.write.mode("overwrite").saveAsTable("wes.renewables")
filtered_df_solar.write.mode("overwrite").saveAsTable("wes.solar")
filtered_df_wind.write.mode("overwrite").saveAsTable("wes.wind")
filtered_df_other_renewables.write.mode("overwrite").saveAsTable("wes.other_renewables")
filtered_df_low_carbon.write.mode("overwrite").saveAsTable("wes.low_carbon")
filtered_df_electricity_imports.write.mode("overwrite").saveAsTable("wes.electricity_imports")


In [9]:
SELECT * FROM wes.renewables;

SyntaxError: invalid syntax (3884440000.py, line 1)