## 1. initialize

In [1]:
from pathlib import Path
import os
import getpass
import shutil

from pyspark.sql import SparkSession
from pyspark import SparkConf, SparkContext

## 2. start spark cluster

In [2]:
# tweak setting here:
def init_spark(cluster=None, name="dsgrid", tz="UTC"):
    """Create (or reuse) a SparkSession for this notebook.

    Parameters
    ----------
    cluster : str or None
        Spark master URL, e.g. ``"spark://<host>:7077"``. ``None`` runs Spark in
        local mode (master ``"local"``).
    name : str
        Application name shown in the Spark UI.
    tz : str
        Session time zone, set as ``spark.sql.session.timeZone``. It is applied in
        both local and cluster mode so timestamps are interpreted consistently.

    Returns
    -------
    pyspark.sql.SparkSession
        The session returned by ``SparkSession.builder.getOrCreate()``.
    """
    master = "local" if cluster is None else cluster
    conf = SparkConf().setAppName(name).setMaster(master)
    conf = conf.setAll([
        # Optional tuning knobs -- uncomment and adjust for the target cluster:
        # ("spark.sql.shuffle.partitions", "200"),
        # ("spark.executor.instances", "7"),
        # ("spark.executor.cores", "5"),
        # ("spark.executor.memory", "10g"),
        # ("spark.driver.memory", "10g"),
        # ("spark.dynamicAllocation.enabled", True),
        # ("spark.shuffle.service.enabled", True),
        ("spark.sql.session.timeZone", tz),
    ])
    return SparkSession.builder.config(conf=conf).getOrCreate()

To launch a standalone cluster or a cluster on Eagle, follow **instructions** here: \
https://github.com/dsgrid/dsgrid/tree/main/dev#spark-standalone-cluster

Then uncomment and update the matching cluster setting below:

In [3]:
# Spark SQL session time zone, e.g. "UTC" or "EST".
main_tz = "EST"  # <---

# ---- Choose ONE way to set `cluster` (uncomment the option you need) ----

# Stand-alone cluster:
# cluster = "spark://lliu2-34727s:7077"

# Cluster on HPC, by node name (update after deploying the cluster):
# NODENAME = "r103u23"
# cluster = f"spark://{NODENAME}.ib0.cm.hpc.nrel.gov:7077"

# Cluster on HPC, read from the cluster.toml file written by prep_spark_cluster_notebook.py:
# import toml
# cluster = toml.load("cluster.toml")["cluster"]

# Local mode:
# cluster = None

# AWS EMR cluster (currently active):
NODENAME = "172.18.27.8"  # alternative: "ec2-54-212-53-141.us-west-2.compute.amazonaws.com"
cluster = f"spark://{NODENAME}:7077"

spark = init_spark(cluster, 'dsgrid-load', tz=main_tz)

# Display the SparkContext. Its Spark UI link is only usable in local mode; it may not
# be reachable when the driver runs on a remote cluster such as Eagle.
sc = spark.sparkContext
sc

Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/06/08 18:38:06 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
22/06/08 18:38:08 WARN StandaloneAppClient$ClientEndpoint: Failed to connect to master 172.18.27.8:7077
org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:301)
	at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRefByURI(RpcEnv.scala:101)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRef(RpcEnv.scala:109)
	at org.apache.spark.deploy.client.StandaloneAppClient$ClientEndpoint$$anon$1.run(StandaloneAppClient.scala:107)
	at java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:511)
	at j

22/06/08 18:39:08 ERROR StandaloneSchedulerBackend: Application has been killed. Reason: All masters are unresponsive! Giving up.
22/06/08 18:39:08 WARN StandaloneSchedulerBackend: Application ID is not initialized yet.
22/06/08 18:39:08 WARN StandaloneAppClient$ClientEndpoint: Drop UnregisterApplication(null) because has not yet connected to master
22/06/08 18:39:08 ERROR SparkContext: Error initializing SparkContext.
java.lang.IllegalArgumentException: requirement failed: Can only call getServletHandlers on a running MetricsSystem
	at scala.Predef$.require(Predef.scala:281)
	at org.apache.spark.metrics.MetricsSystem.getServletHandlers(MetricsSystem.scala:89)
	at org.apache.spark.SparkContext.<init>(SparkContext.scala:603)
	at org.apache.spark.api.java.JavaSparkContext.<init>(JavaSparkContext.scala:58)
	at sun.reflect.NativeConstructorAccessorImpl.newInstance0(Native Method)
	at sun.reflect.NativeConstructorAccessorImpl.newInstance(NativeConstructorAccessorImpl.java:62)
	at sun.reflec

Py4JJavaError: An error occurred while calling None.org.apache.spark.api.java.JavaSparkContext.
: java.lang.IllegalArgumentException: requirement failed: Can only call getServletHandlers on a running MetricsSystem
	at scala.Predef$.require(Predef.scala:281)
	at org.apache.spark.metrics.MetricsSystem.getServletHandlers(MetricsSystem.scala:89)
	at org.apache.spark.SparkContext.<init>(SparkContext.scala:603)
	at org.apache.spark.api.java.JavaSparkContext.<init>(JavaSparkContext.scala:58)
	at sun.reflect.NativeConstructorAccessorImpl.newInstance0(Native Method)
	at sun.reflect.NativeConstructorAccessorImpl.newInstance(NativeConstructorAccessorImpl.java:62)
	at sun.reflect.DelegatingConstructorAccessorImpl.newInstance(DelegatingConstructorAccessorImpl.java:45)
	at java.lang.reflect.Constructor.newInstance(Constructor.java:423)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:247)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:238)
	at py4j.commands.ConstructorCommand.invokeConstructor(ConstructorCommand.java:80)
	at py4j.commands.ConstructorCommand.execute(ConstructorCommand.java:69)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.lang.Thread.run(Thread.java:750)


#### The *Spark UI* link above works only in local mode. For a Spark cluster on HPC, open the Spark UI at:
http://localhost:8080

In [None]:
# Show the effective Spark configuration as (key, value) pairs.
spark_conf = sc.getConf()
spark_conf.getAll()

## 3. dsgrid

In [None]:
# `IPython.core.display.display` is deprecated; import from the public IPython.display.
from IPython.display import display, HTML

# Widen the notebook's content area so wide tables are easier to read.
display(HTML("<style>.container { width:100% !important; }</style>"))

import itertools
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
import plotly
import pytz
import pyspark.sql.functions as F
import pyspark.sql.types as sparktypes
from pydantic import ValidationError
from semver import VersionInfo

# Show at most 20 rows when displaying pandas objects.
pd.set_option('display.max_rows', 20)
# Use plotly as the backend for pandas .plot().
pd.options.plotting.backend = "plotly"

In [None]:
from dsgrid.common import LOCAL_REGISTRY
from dsgrid.registry.registry_manager import RegistryManager
from dsgrid.utils.files import load_data
from dsgrid.utils.spark import create_dataframe, read_dataframe, get_unique_values
from dsgrid.dimension.base_models import DimensionType
from dsgrid.dataset.dataset import Dataset
from dsgrid.project import Project
from dsgrid.dimension.time import TimeZone

## 3.1. Check dsgrid registry

In [None]:
## sync registry and then load offline
# Registry location: $DSGRID_REGISTRY_PATH when set, otherwise dsgrid's LOCAL_REGISTRY.
registry_path = os.environ.get("DSGRID_REGISTRY_PATH", LOCAL_REGISTRY)
registry_path

In [None]:
# When True, load the registry once with offline_mode=False to sync it before the
# offline load in the next cell. Set to False to skip this step.
sync_and_pull = True  # <--- registry config only

if sync_and_pull:
    print(f"syncing registry: {registry_path}")
    RegistryManager.load(registry_path, offline_mode=False)

In [None]:
# Load the registry with offline_mode=True (the sync cell above uses offline_mode=False).
offline_mode = True  # <--- True: offline, False: online

registry_mgr = RegistryManager.load(registry_path, offline_mode=offline_mode)

# Handles to the individual managers used in later cells.
project_mgr, dataset_mgr = registry_mgr.project_manager, registry_mgr.dataset_manager
dim_map_mgr, dim_mgr = registry_mgr.dimension_mapping_manager, registry_mgr.dimension_manager

print(f"Loaded dsgrid registry at: {registry_path}")

In [None]:
# Summary table of registered projects, hiding the Date and Submitter columns.
project_mgr.show(drop_fields=["Date", "Submitter"], max_width=30)

## 3.2. Load Project
This section is mostly exploratory. For *Section 4. Queries*, only the project needs to be loaded.

####  Some user criteria:
At the project level, I want to be able to:
- Examine what's available in the project:
    * Show project dimensions and resolutions by type, regardless of base/supplemental status, mappings, or IDs
    * Get unique records by dimension/resolution
    * Get unique records by selected dimension sets
    * Show mapped dataset
    * Show unit (or select a unit of analysis) and fuel types
- Make queries using:
    * Project dimensions + fuel types + time resolutions
    * Get all types of statistics (max, mean, min, percentiles, count, sum)
    
- Query at the dataset level, for data that is never mapped to the project (e.g., TEMPO)
- Provide an interface that allows for query optimization
    
#### Notes:
 * Project_manager has access to all other managers.
 * Each manager is responsible for retrieving its configs
 * Access ConfigModel from configs

In [None]:
# Load the project from the registry (same offline_mode as the registry load above).
project_id = "dsgrid_conus_2022"  # <---
project = Project.load(project_id, offline_mode=offline_mode)
print("project loaded")

## 3.3. Load Project Datasets

### 3.3.3. TEMPO

Load and check the TEMPO dataset here.

In [None]:
# Load the TEMPO dataset into the project, then fetch a handle to it.
dataset_id = "tempo_conus_2022"  # <----
project.load_dataset(dataset_id)
tempo = project.get_dataset(dataset_id)
print("tempo dataset loaded")

In [None]:
### TO BE DELETED ###
# Temporary: take load_data_lookup from the loaded dataset, but read load_data from a
# standalone Parquet file instead of from the dataset.
tempo_load_data_lookup = tempo.load_data_lookup

file = "/scratch/dthom/tempo_load_data3.parquet"  # <---
tempo_load_data = spark.read.parquet(file)

In [None]:
# Remap dimension columns of both TEMPO tables via the dataset handler's private
# _remap_dimension_columns (presumably onto project dimension records -- confirm).
tempo_mapped_load_data_lookup, tempo_mapped_load_data = (
    tempo._handler._remap_dimension_columns(frame)
    for frame in (tempo_load_data_lookup, tempo_load_data)
)

In [None]:
# Only the remapped frames are used from here on.
del tempo_load_data_lookup, tempo_load_data

## 4. Queries
### Query util functions

### 4.1. Hourly electricity consumption by *scenario, model_year, and ReEDS PCA*

In [None]:
### all_enduses-totelectric_enduses map
dim_map_id = "conus-2022-detailed-end-uses-kwh__all-electric-end-uses__c4149547-1209-4ce3-bb4c-3ab292067e8a"  # <---
electric_enduses_map = dim_map_mgr.get_by_id(dim_map_id).get_records_dataframe()

### get all project electric end uses
# End uses whose mapping has a non-null to_id are the electric ones.
electric_enduses = (
    electric_enduses_map
    .filter("to_id is not NULL")
    .select("from_id")
    .toPandas()["from_id"]
    .to_list()
)
electric_enduses

In [None]:
### county-to-PCA map
dim_map_id = "us_counties_2020_l48__reeds_pca__fcc554e1-87c9-483f-89e3-a0df9563cf89"  # <---
county_to_pca_map = (
    dim_map_mgr
    .get_by_id(dim_map_id)
    .get_records_dataframe()
)
county_to_pca_map.show()


### 4.1.3. TEMPO
Query the TEMPO data here.

In [None]:
def make_annual_dataframe(df, data_id, year):
    """Expand one TEMPO profile into an hourly time series covering a full calendar year.

    ``df`` holds representative profiles keyed by (month, day_of_week, hour). For every
    calendar day of ``year``, the 24 hourly rows for that day's month and day of week
    are copied into the output in hour order.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        Load data with columns ``id``, ``month``, ``day_of_week``, ``hour``,
        ``L1andL2`` and ``DCFC``. ``day_of_week`` is matched against
        ``datetime.weekday()`` (Monday=0).
        NOTE(review): confirm TEMPO uses the same weekday convention.
    data_id : int or str
        Profile id to extract; interpolated directly into a Spark SQL filter.
    year : int
        Calendar year to generate. Leap years follow the Gregorian rule
        (``calendar.isleap``), so e.g. 2100 has 365 days.

    Returns
    -------
    pyspark.sql.DataFrame
        One row per hour of ``year`` with columns ``timestamp`` (hourly, starting at
        midnight on January 1 UTC), ``L1andL2``, ``DCFC`` and ``id``.

    Raises
    ------
    ValueError
        If a (month, day_of_week) slice does not contain exactly 24 hourly rows.
    """
    import calendar

    dfs = df.filter(f"id = {data_id}")
    days_per_year = 366 if calendar.isleap(year) else 365
    hours_per_year = days_per_year * 24
    td = timedelta(hours=1)
    # Set to UTC to avoid DST problems
    start_time = datetime(year=year, month=1, day=1, hour=0, tzinfo=pytz.UTC)
    # TODO: make value columns dynamic
    value_columns = ("L1andL2", "DCFC", "id")
    overall = {
        "timestamp": np.array([start_time + i * td for i in range(hours_per_year)]),
        "L1andL2": np.empty(hours_per_year, np.float32),
        "DCFC": np.empty(hours_per_year, np.float32),
        "id": np.empty(hours_per_year, np.int64),
    }

    index = 0
    for month in range(1, 13):
        df_month = dfs.filter(f"month = {month}")
        num_days = calendar.monthrange(year, month)[1]
        # Profiles depend only on (month, day_of_week): collect each one once per month.
        df_by_day = {}
        for day in range(1, num_days + 1):
            day_of_week = datetime(year=year, month=month, day=day).weekday()
            if day_of_week not in df_by_day:
                day_profile = df_month.filter(f"day_of_week = {day_of_week}").toPandas()
                if len(day_profile) != 24:
                    raise ValueError(
                        f"id={data_id}, month={month}, day_of_week={day_of_week}: "
                        f"expected 24 hourly rows, found {len(day_profile)}"
                    )
                # Sort numerically: a string-typed hour column would otherwise sort
                # lexicographically ("0", "1", "10", "11", ...).
                df_by_day[day_of_week] = day_profile.sort_values(
                    by="hour", key=lambda hours: hours.astype(int)
                )
            end = index + 24
            for column in value_columns:
                overall[column][index:end] = df_by_day[day_of_week][column].values
            index = end

    assert index == hours_per_year, index
    return SparkSession.getActiveSession().createDataFrame(pd.DataFrame(overall))


In [None]:
## Load timezone map (not registered)
# County -> time zone mapping read from CSV; every row gets from_fraction = 1.
timezone_file = "/scratch/lliu2/project_county_timezone/county_fip_to_local_prevailing_time.csv"
tz_map = spark.read.csv(timezone_file, header=True).withColumn("from_fraction", F.lit(1))
tz_map.show()

In [None]:
### get electric end uses for transportation
# Keep load-data column order; retain only columns that are electric end uses.
tra_elec_enduses = list(filter(electric_enduses.__contains__, tempo_mapped_load_data.columns))
tra_elec_enduses

In [None]:
### TO BE DELETED
# tempo_mapped_load_data_lookup = tempo_mapped_load_data_lookup.filter("id in ('1621180393', '770011011', '1058530452')")
# tempo_mapped_load_data = tempo_mapped_load_data.filter("id in ('1621180393', '770011011', '1058530452')")

In [None]:
%%time
## 0. consolidate load_data: get total hourly electricity consumption by id
# make get_time_cols accessible at dataset level
if not tra_elec_enduses:
    # Built-in sum() of an empty list returns the int 0, which has no .alias();
    # fail with a clear message instead of an AttributeError.
    raise ValueError("tempo_mapped_load_data has no electric end-use columns to sum")
tra_elec_kwh = tempo_mapped_load_data.select(
    "id",
    "day_of_week",
    "hour",
    "month",
    # Python's built-in sum folds the Columns into one expression: 0 + col1 + col2 + ...
    sum([F.col(col) for col in tra_elec_enduses]).alias("electricity"),
)
# tra_elec_kwh.show()

In [None]:
%%time
## 1. map load_data_lookup to timezone (left join on county -> tz_map.from_id)
load_data_lookup = (
    tempo_mapped_load_data_lookup
    .filter("id is not NULL")
    .select("sector", "scenario", "model_year", "geography", "id", "fraction")
    .join(tz_map, on=F.col("geography") == tz_map.from_id, how="left")
    .drop("from_id")
    .withColumnRenamed("to_id", "timezone")
)

## combine fraction
# Rows without a time-zone match get from_fraction = 1, leaving their fraction unchanged.
nonfraction_cols = [c for c in load_data_lookup.columns if c not in {"fraction", "from_fraction"}]
load_data_lookup = (
    load_data_lookup
    .fillna(1, subset=["from_fraction"])
    .selectExpr(*nonfraction_cols, "fraction*from_fraction AS fraction")
)
# load_data_lookup.show()

In [None]:
%%time
## 2. join load_data and lookup, weighting each id's electricity by its lookup fraction
tra_elec_kwh = (
    load_data_lookup
    .join(tra_elec_kwh, on="id", how="left")
    .drop("id")
    .groupBy(
        "sector", "scenario", "geography", "model_year",
        "timezone", "day_of_week", "month", "hour",
    )
    .agg(F.sum(F.col("fraction") * F.col("electricity")).alias("electricity"))
)

## cache df
# tra_elec_kwh = tra_elec_kwh.cache()
# tra_elec_kwh.show()

In [None]:
%%time
year = 2012  # <--- weather year
sys_tz = TimeZone.EST.tz
timezones_local = [TimeZone.EPT, TimeZone.CPT, TimeZone.MPT, TimeZone.PPT]

## 3. build the hourly time table for `year`
# For each local prevailing time zone, enumerate every local hour of `year`, record the
# local (day_of_week, month, hour) join keys, and convert the timestamp to sys_tz.
# Keys are stored as strings. NOTE(review): assumes the matching columns in
# tra_elec_kwh are string-typed -- confirm against the dataset schema.
model_time = []
for tz_local in timezones_local:
    model_time_df = pd.DataFrame()
    # Create the time range in local time. Pass the tzinfo (`.tz`), as is done for
    # sys_tz, rather than the TimeZone enum member itself.
    model_time_df["timestamp"] = pd.date_range(
        start=datetime(year=year, month=1, day=1, hour=0),
        end=datetime(year=year, month=12, day=31, hour=23),
        tz=tz_local.tz,
        freq="H",
    )
    model_time_df["timezone"] = tz_local.value
    model_time_df["day_of_week"] = model_time_df["timestamp"].dt.day_of_week.astype(str)
    model_time_df["month"] = model_time_df["timestamp"].dt.month.astype(str)
    model_time_df["hour"] = model_time_df["timestamp"].dt.hour.astype(str)

    # convert to the main (system) time zone
    model_time_df["timestamp"] = model_time_df["timestamp"].dt.tz_convert(sys_tz)
    # wrap hours that land in an adjacent calendar year after conversion back into `year`
    model_time_df["timestamp"] = model_time_df["timestamp"].apply(lambda x: x.replace(year=year))

    model_time.append(model_time_df)

model_time = pd.concat(model_time, axis=0, ignore_index=True)
print(model_time)

# convert to spark df
schema = sparktypes.StructType([
    sparktypes.StructField("timestamp", sparktypes.TimestampType(), False),
    sparktypes.StructField("timezone", sparktypes.StringType(), False),
    sparktypes.StructField("day_of_week", sparktypes.StringType(), False),
    sparktypes.StructField("month", sparktypes.StringType(), False),
    sparktypes.StructField("hour", sparktypes.StringType(), False),
])
model_time = SparkSession.getActiveSession().createDataFrame(model_time, schema=schema)
print(model_time.count())
model_time.show()

In [None]:
%%time
## 4. expand to hourly timestamps: attach each (timezone, day_of_week, month, hour)
## row of tra_elec_kwh to the matching sys_tz timestamps from model_time
tra_elec_kwh = (
    model_time
    .join(tra_elec_kwh, on=["timezone", "day_of_week", "month", "hour"], how="right")
    .drop("day_of_week", "month", "hour")
)

## cache df
# tra_elec_kwh = tra_elec_kwh.cache()
# tra_elec_kwh.show()

In [None]:
%%time
# 5. map geography from county to ReEDS PCA, then aggregate electricity per PCA
tra_elec_kwh = (
    tra_elec_kwh
    .join(county_to_pca_map, on=F.col("geography") == county_to_pca_map.from_id, how="left")
    .drop("from_id", "geography")
    .withColumnRenamed("to_id", "geography")
    .groupBy("sector", "scenario", "geography", "model_year", "timestamp")
    .agg(F.sum("electricity").alias("electricity"))
)

# tra_elec_kwh.show()

In [None]:
%%time
### 6. save as partitions
tra_output_file = Path(f"/scratch/{getpass.getuser()}/tempo_projections.parquet")
overwrite_output = True  # <--- set False to refuse replacing an existing output

# refresh file dir
if tra_output_file.exists():
    if not overwrite_output:
        raise ValueError(
            f"file: {tra_output_file} already exists. Set overwrite_output = True "
            "(or run `shutil.rmtree(tra_output_file)`) to override."
        )
    shutil.rmtree(tra_output_file)

# repartition() shuffles rows and discards any earlier global sort, so sort within each
# partition instead. partitionBy then writes one directory per (scenario, model_year),
# with rows in each file ordered by geography and timestamp.
# mode("overwrite") lets the cell be re-run in the same session, where the
# "tra_elec_kwh" table is already registered in the catalog.
(
    tra_elec_kwh
    .repartition("scenario", "model_year")
    .sortWithinPartitions("scenario", "model_year", "geography", "timestamp")
    .write
    .partitionBy("scenario", "model_year")
    .mode("overwrite")
    .option("path", str(tra_output_file))
    .saveAsTable("tra_elec_kwh", format='parquet')
)

print("tra_elec_kwh saved")

In [None]:
%%time
########## load transportation projection data ###########
tra_output_file = Path(f"/scratch/{getpass.getuser()}/tempo_projections.parquet")

if not tra_output_file.exists():
    print(f"tra_output_file={tra_output_file} does not exist")
else:
    tra_elec_kwh = read_dataframe(tra_output_file)
    print("tra_elec_kwh loaded")