## 1. initialize

In [None]:
from pathlib import Path
import os
import getpass
import shutil

from pyspark.sql import SparkSession
from pyspark import SparkConf, SparkContext

In [None]:
# Inspect the process environment. Values whose variable names look secret-bearing
# are masked so credentials (e.g. AWS keys when running on an AWS cluster) are not
# captured in the saved notebook output.
SENSITIVE_ENV_MARKERS = ("SECRET", "TOKEN", "PASSWORD", "PASSWD", "CREDENTIAL", "KEY")


def masked_environ(environ=None):
    """Return a key-sorted copy of ``environ`` with secret-looking values masked.

    Parameters
    ----------
    environ : mapping of str to str, optional
        Environment to render; defaults to ``os.environ``.

    Returns
    -------
    dict
        Same keys as ``environ`` (sorted). A value is replaced by ``"***"`` when the
        upper-cased variable name contains any entry of ``SENSITIVE_ENV_MARKERS``.
    """
    if environ is None:
        environ = os.environ
    return {
        key: "***" if any(marker in key.upper() for marker in SENSITIVE_ENV_MARKERS) else value
        for key, value in sorted(environ.items())
    }


masked_environ()

## 2. start spark cluster

In [None]:
# Tweak Spark settings in the `setAll` list inside init_spark below.
def init_spark(cluster=None, name="dsgrid", tz="UTC"):
    """Initialize (or reuse) a SparkSession.

    Parameters
    ----------
    cluster : str or None
        ``None`` runs Spark in local mode (master ``"local"``). ``"AWS"`` leaves the
        master unset so the cluster environment provides it. Any other value is
        used as the master URL, e.g. ``"spark://<host>:7077"``.
    name : str
        Spark application name.
    tz : str
        Session time zone, applied as ``spark.sql.session.timeZone``.

    Returns
    -------
    pyspark.sql.SparkSession
        The session returned by ``SparkSession.builder.getOrCreate()``; an
        already-running session is reused.
    """
    conf = SparkConf().setAppName(name)

    if cluster is None:
        # Configure local mode on the same conf. Previously a separate local
        # session was created here and its handle was immediately discarded.
        conf = conf.setMaster("local")
    elif cluster != "AWS":
        conf = conf.setMaster(cluster)
    # cluster == "AWS": no master is set; the AWS cluster supplies it.

    conf = conf.setAll(
        [
            # ("spark.sql.shuffle.partitions", "200"),
            # ("spark.executor.instances", "7"),
            # ("spark.executor.cores", "5"),
            # ("spark.executor.memory", "10g"),
            # ("spark.driver.memory", "10g"),
            # ("spark.dynamicAllocation.enabled", True),
            # ("spark.shuffle.service.enabled", True),
            ("spark.sql.session.timeZone", tz),
        ]
    )
    return SparkSession.builder.config(conf=conf).getOrCreate()

To launch a standalone cluster or a cluster on Kestrel, follow **instructions** here: \
https://github.com/dsgrid/dsgrid/tree/main/dev#spark-standalone-cluster

Accordingly, uncomment the matching configuration below and update the cluster address:

In [None]:
main_tz = "EST"  # <--- Spark session time zone: "UTC" or "EST"

# Pick ONE of the cluster configurations below by uncommenting it.

### STAND-ALONE CLUSTER
# cluster = "spark://lliu2-34727s:7077"
# name = "stand-alone"

### CLUSTER ON HPC - Type in nodename
# NODENAME = "r103u23" # <--- change after deploying cluster
# cluster = f"spark://{NODENAME}.ib0.cm.hpc.nrel.gov:7077"
# name = "HPC"

### CLUSTER ON HPC - Get cluster from file dropped by prep_spark_cluster_notebook.py
# import toml
# config = toml.load("cluster.toml")
# cluster = config["cluster"]
# name = "HPC"

### LOCAL MODE
# cluster = None
# name = "local"

### AWS MODE
cluster = "AWS"
name = "AWS"
# NOTE(review): `name` only labels the selected mode; it is not passed to
# init_spark below, which always uses the application name "dsgrid-load".

# Create (or reuse) the session for the selected cluster and time zone.
spark = init_spark(cluster=cluster, name="dsgrid-load", tz=main_tz)

# Display the SparkContext (its notebook repr links to the Spark UI).
sc = spark.sparkContext
sc

#### The *Spark UI* above works only for local mode. For HPC cluster Spark UI, use:
http://localhost:8080

In [None]:
# Print the effective Spark configuration as sorted (key, value) pairs.
for conf_key, conf_value in sorted(sc.getConf().getAll()):
    print((conf_key, conf_value))

## 3. dsgrid

In [None]:
# Notebook display setup and analysis imports.
# Import from IPython.display: importing `display` from IPython.core.display is
# deprecated (since IPython 7.14).
from IPython.display import display, HTML
import pandas as pd
# import plotly
# pd.options.plotting.backend = "plotly"
import numpy as np
import itertools
import pytz
from datetime import datetime, timedelta

from semver import VersionInfo
from pydantic import ValidationError
import pyspark.sql.functions as F
import pyspark.sql.types as sparktypes

# Stretch the notebook container to the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))
# Truncate displayed pandas frames to 20 rows.
pd.set_option("display.max_rows", 20)

In [None]:
from dsgrid.common import LOCAL_REGISTRY
from dsgrid.registry.registry_manager import RegistryManager
from dsgrid.utils.files import load_data
from dsgrid.utils.spark import create_dataframe, read_dataframe, get_unique_values
from dsgrid.dimension.base_models import DimensionType
from dsgrid.dataset.dataset import Dataset
from dsgrid.project import Project
from dsgrid.dimension.time import TimeZone

## 3.1. Check dsgrid registry

In [None]:
## Sync the registry, then load it offline.
# LOCAL_REGISTRY = "s3://nrel-dsgrid-registry-archive"
# An explicit DSGRID_REGISTRY_PATH in the environment wins; otherwise use dsgrid's LOCAL_REGISTRY.
registry_path = os.environ.get("DSGRID_REGISTRY_PATH", LOCAL_REGISTRY)
registry_path

In [None]:
sync_and_pull = True  # <--- set to False to skip syncing the registry config
# Loading once with offline_mode=False performs the sync. The returned manager is
# not kept because the next cell reloads the registry with the chosen offline_mode.
if sync_and_pull:
    sync_message = f"syncing registry: {registry_path}"
    print(sync_message)
    RegistryManager.load(registry_path, offline_mode=False)

In [None]:
offline_mode = True  # <--- set to False to sync with the remote registry while loading

# Load the registry and keep handles to each of its managers.
registry_mgr = RegistryManager.load(registry_path, offline_mode=offline_mode)
project_mgr, dataset_mgr, dim_map_mgr, dim_mgr = (
    registry_mgr.project_manager,
    registry_mgr.dataset_manager,
    registry_mgr.dimension_mapping_manager,
    registry_mgr.dimension_manager,
)
print(f"Loaded dsgrid registry at: {registry_path}")

In [None]:
# List registered projects, hiding the Date and Submitter fields.
hidden_fields = ["Date", "Submitter"]
project_mgr.show(drop_fields=hidden_fields, max_width=30)

In [None]:
# %%timeit
# ## Dan's test
# from dsgrid.config.time_dimension_base_config import TimeDimensionBaseConfig

# i = 0
# for d_id in registry_mgr.dimension_manager._id_to_type:
#     config = registry_mgr.dimension_manager.get_by_id(d_id)
#     if not isinstance(config, TimeDimensionBaseConfig):
#         config.get_records_dataframe().count()
#         i += 1

# print(i)

## 3.2. Load Project
This section is mostly exploratory; *Section 4. Queries* only requires the project to be loaded.

####  Some user criteria:
At the project level, I want to be able to:
- Examine what's available in the project:
    * Show project dimensions and resolutions by type, without having to know whether they are base/supplemental, their mappings, or their IDs
    * Get unique records by dimension/resolution
    * Get unique records by selected dimension sets
    * Show mapped datasets
    * Show unit (or select a unit of analysis) and fuel types
- Make queries using:
    * Project dimensions + fuel types + time resolutions
    * Get all types of statistics (max, mean, min, percentiles, count, sum)
    
- Work at the dataset level with data that is never mapped to the project (e.g., TEMPO)
- Provide an interface that allows for query optimization
    
#### Notes:
 * The project manager has access to all other managers.
 * Each manager is responsible for retrieving its configs.
 * Access the ConfigModel through the configs.

In [None]:
# Load the project (values marked "<---" are meant to be edited).
project_id = "dsgrid_conus_2022"  # <---
project = project_mgr.load_project(project_id)
print("project loaded")

## 3.3. Load Project Datasets

### 3.3.3. TEMPO

Load and check the TEMPO dataset here.

In [None]:
# Load the TEMPO dataset into the project, then fetch the loaded dataset object.
dataset_id = "tempo_conus_2022"  # <----
project.load_dataset(dataset_id)
tempo = project.get_dataset(dataset_id)
print("tempo dataset loaded")

In [None]:
### TO BE DELETED ###
# Pull the raw TEMPO tables off the dataset; the next cell remaps them and the
# unmapped copies are deleted afterwards.
tempo_load_data_lookup, tempo_load_data = tempo.load_data_lookup, tempo.load_data

# Alternative source for load_data:
# file = "/scratch/dthom/tempo_load_data3.parquet" # <---
# tempo_load_data = spark.read.parquet(file)

In [None]:
# Remap the dimension columns of both TEMPO tables.
# NOTE: relies on the private dsgrid API `_handler._remap_dimension_columns`.
remap_columns = tempo._handler._remap_dimension_columns
tempo_mapped_load_data_lookup = remap_columns(tempo_load_data_lookup)
tempo_mapped_load_data = remap_columns(tempo_load_data)

In [None]:
# Release the unmapped tables; only the remapped versions are used below.
del tempo_load_data_lookup, tempo_load_data

## 4. Queries
### Query util functions

### 4.1. Hourly electricity consumption by *scenario, model_year, and ReEDS PCA*

In [None]:
### all_enduses-totelectric_enduses map: project end uses -> electric end uses
dim_map_id = "conus-2022-detailed-end-uses-kwh__all-electric-end-uses__c4149547-1209-4ce3-bb4c-3ab292067e8a"  # <---
electric_enduses_map = dim_map_mgr.get_by_id(dim_map_id).get_records_dataframe()

### Project end uses that map to an electric end use (rows with a non-null to_id)
electric_enduse_pdf = electric_enduses_map.filter("to_id is not NULL").select("from_id").toPandas()
electric_enduses = electric_enduse_pdf["from_id"].to_list()
electric_enduses

In [None]:
### County -> ReEDS PCA dimension mapping records
dim_map_id = "us_counties_2020_l48__reeds_pca__fcc554e1-87c9-483f-89e3-a0df9563cf89"  # <---
county_to_pca_config = dim_map_mgr.get_by_id(dim_map_id)
county_to_pca_map = county_to_pca_config.get_records_dataframe()
county_to_pca_map.show()

### 4.1.3. TEMPO
Query the TEMPO data here.

In [None]:
## County -> local prevailing time zone map (not registered in dsgrid)
# Local alternative: "/scratch/lliu2/project_county_timezone/county_fip_to_local_prevailing_time.csv"
timezone_file = "s3://nrel-dsgrid-int-scratch/scratch-lliu2/county_fip_to_local_prevailing_time.csv"
# A constant from_fraction of 1 assigns each county wholly to its listed time zone.
tz_map = spark.read.csv(timezone_file, header=True).withColumn("from_fraction", F.lit(1))
tz_map.show()

In [None]:
### TEMPO load_data columns that are electric end uses (column order preserved)
electric_enduse_set = set(electric_enduses)
tra_elec_enduses = [c for c in tempo_mapped_load_data.columns if c in electric_enduse_set]
tra_elec_enduses

In [None]:
### TO BE DELETED
# tempo_mapped_load_data_lookup = tempo_mapped_load_data_lookup.filter("id in ('1621180393', '770011011', '1058530452')")
# tempo_mapped_load_data = tempo_mapped_load_data.filter("id in ('1621180393', '770011011', '1058530452')")

In [None]:
%%time
## 0. Consolidate load_data: total hourly electricity consumption per id.
# TODO: make get_time_cols accessible at the dataset level instead of listing time columns here.
profile_time_cols = ["day_of_week", "hour", "month"]
# Row-wise sum of the electric end-use columns.
total_electricity = sum(F.col(enduse) for enduse in tra_elec_enduses)
tra_elec_kwh = tempo_mapped_load_data.select(
    "id", *profile_time_cols, total_electricity.alias("electricity")
)
# tra_elec_kwh.show()

In [None]:
%%time
## 1. map load_data_lookup to timezone
load_data_lookup = (
    tempo_mapped_load_data_lookup.filter("id is not NULL")
    .select("sector", "scenario", "model_year", "geography", "id", "fraction")
    .join(
        tz_map,
        on=F.col("geography") == tz_map.from_id,
        how="left",
    )
    .drop("from_id")
    .withColumnRenamed("to_id", "timezone")
)

## combine fraction
nonfraction_cols = [x for x in load_data_lookup.columns if x not in {"fraction", "from_fraction"}]
load_data_lookup = load_data_lookup.fillna(1, subset=["from_fraction"]).selectExpr(
    *nonfraction_cols, "fraction*from_fraction AS fraction"
)
# load_data_lookup.show()

In [None]:
%%time
## 2. join load_data and lookup
tra_elec_kwh = load_data_lookup.join(
    tra_elec_kwh,
    on="id",
    how="left",
).drop("id")

tra_elec_kwh = tra_elec_kwh.groupBy(
    "sector",
    "scenario",
    "geography",
    "model_year",
    "timezone",
    "day_of_week",
    "month",
    "hour",
).agg(F.sum(F.col("fraction") * F.col("electricity")).alias("electricity"))

## cache df
# tra_elec_kwh = tra_elec_kwh.cache()
# tra_elec_kwh.show()

In [None]:
%%time
year = 2012  # <--- weather year
# Target ("system") time zone. Derive it from main_tz, the Spark session time zone
# chosen at startup, so the pandas conversion below and Spark's string parsing agree.
# main_tz = "EST" resolves to TimeZone.EST, the value previously hard-coded here.
# NOTE(review): assumes main_tz is a TimeZone member name (e.g. "EST", "UTC").
sys_tz = TimeZone[main_tz].tz
timezones_local = [TimeZone.EPT, TimeZone.CPT, TimeZone.MPT, TimeZone.PPT]

## 3. Build an hourly time table for the weather year in each local prevailing time zone.
model_time_pd = []
for tz in timezones_local:
    model_time_df = pd.DataFrame()
    # Hourly range in local (prevailing) time.
    model_time_df["timestamp"] = pd.date_range(
        start=datetime(year=int(year), month=1, day=1, hour=0),
        end=datetime(year=int(year), month=12, day=31, hour=23),
        tz=tz.tz,
        freq="H",
    )
    model_time_df["timezone"] = tz.value
    # Local-time attributes (as strings) used later as join keys with the profiles.
    model_time_df["day_of_week"] = model_time_df["timestamp"].dt.day_of_week.astype(str)
    model_time_df["month"] = model_time_df["timestamp"].dt.month.astype(str)
    model_time_df["hour"] = model_time_df["timestamp"].dt.hour.astype(str)

    # Convert to the system time zone.
    model_time_df["timestamp"] = model_time_df["timestamp"].dt.tz_convert(sys_tz)
    # Wrap hours that crossed the year boundary back into the weather year.
    model_time_df["timestamp"] = model_time_df["timestamp"].apply(lambda x: x.replace(year=year))

    model_time_pd.append(model_time_df)

model_time_pd = pd.concat(model_time_pd, axis=0, ignore_index=True)
# Convert timestamps to naive strings; this is important so Spark does not re-shift them.
model_time_pd["timestamp"] = model_time_pd["timestamp"].dt.tz_localize(None).astype(str)
print(model_time_pd)

# Convert to a Spark DataFrame.
schema = sparktypes.StructType(
    [
        sparktypes.StructField("timestamp", sparktypes.StringType(), False),
        sparktypes.StructField("timezone", sparktypes.StringType(), False),
        sparktypes.StructField("day_of_week", sparktypes.StringType(), False),
        sparktypes.StructField("month", sparktypes.StringType(), False),
        sparktypes.StructField("hour", sparktypes.StringType(), False),
    ]
)
model_time = spark.createDataFrame(model_time_pd, schema=schema)

## Parse the naive "yyyy-MM-dd HH:mm:ss" strings into timestamps (interpreted in the
## Spark session time zone). A single to_timestamp replaces the previous
## unix_timestamp -> from_unixtime -> to_timestamp round trip.
model_time = model_time.withColumn(
    "timestamp", F.to_timestamp(F.col("timestamp"), "yyyy-MM-dd HH:mm:ss")
).cache()

# printSchema() prints directly and returns None, so it is not wrapped in print().
model_time.printSchema()
print(model_time.count())
model_time.show()

In [None]:
%%time
## 4. Expand the day_of_week/month/hour profiles to full timestamps via model_time.
time_join_keys = ["timezone", "day_of_week", "month", "hour"]
tra_elec_kwh_expanded = model_time.join(tra_elec_kwh, on=time_join_keys, how="right")
tra_elec_kwh = tra_elec_kwh_expanded.drop("day_of_week", "month", "hour")

## cache df
# tra_elec_kwh = tra_elec_kwh.cache()
# tra_elec_kwh.show()

In [None]:
%%time
# 5. Replace county geography with ReEDS PCA and re-aggregate electricity.
county_matches_pca = F.col("geography") == county_to_pca_map.from_id
pca_group_cols = ["sector", "scenario", "geography", "model_year", "timestamp"]
tra_elec_kwh = (
    tra_elec_kwh.join(county_to_pca_map, on=county_matches_pca, how="left")
    .drop("from_id", "geography")
    .withColumnRenamed("to_id", "geography")
    .groupBy(*pca_group_cols)
    .agg(F.sum("electricity").alias("electricity"))
)

# tra_elec_kwh.show()

In [None]:
%%time
### 6. Save as a parquet table partitioned by scenario and model_year.
tra_output_file = "s3://nrel-dsgrid-int-scratch/scratch-lliu2/tempo_projections.parquet"  # Path(f"/scratch/{getpass.getuser()}/tempo_projections.parquet")

# Replace any previous output via Spark's overwrite mode. pathlib cannot inspect
# remote URIs: Path("s3://...").exists() checks a *local* relative path and is always
# False, so the old shutil.rmtree "refresh" never removed S3 output and reruns failed.
(
    tra_elec_kwh.repartition("scenario", "model_year")
    # Sort *after* repartitioning: a global sort followed by repartition() is undone
    # by the shuffle. Leading with the partition columns keeps the order the writer
    # needs for partitionBy, so geography/timestamp ordering survives into the files.
    .sortWithinPartitions("scenario", "model_year", "geography", "timestamp")
    .write.mode("overwrite")
    .partitionBy("scenario", "model_year")
    .option("path", tra_output_file)
    .saveAsTable("tra_elec_kwh", format="parquet")
)

print("tra_elec_kwh saved")

In [None]:
# %%time
# ########## load transportation projection data ###########
# tra_output_file = "s3://nrel-dsgrid-int-scratch/scratch-lliu2/tempo_projections.parquet" #Path(f"/scratch/{getpass.getuser()}/tempo_projections.parquet")

# if Path(tra_output_file).exists():
#     tra_elec_kwh = read_dataframe(tra_output_file)
#     print("tra_elec_kwh loaded")
# else:
#     print(f"tra_output_file={tra_output_file} does not exist")

In [None]:
%%time
# Row count per timestamp as a time-ordered pandas frame (coverage sanity check).
ts = tra_elec_kwh.groupBy("timestamp").count().sort("timestamp").toPandas()
ts