In [1]:
import datetime as dt
import logging
import json
from pathlib import Path

from pyspark.sql import SparkSession
import pandas as pd
from s3pathlib import S3Path
from pyarrow import fs

# Notebook-wide logging at INFO level, plus a logger for this notebook's own messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Where the dataset lives and which dataset directory under it is analyzed.
dataset_name = "full_dataset"
data_dir = Path("/datasets/dsgrid/dsgrid-tempo-v2022/")

# Spark options this notebook sets explicitly.
# NOTE(review): "EST" is a short time zone ID; presumably it resolves to a fixed
# UTC-05:00 offset with no daylight saving shift -- confirm against the Spark docs.
explicit_options = {
    "spark.sql.sources.partitionColumnTypeInference.enabled": "false",
    "spark.sql.session.timeZone": "EST",
}

builder = SparkSession.builder.appName("dsgrid")
for option_name, option_value in explicit_options.items():
    builder = builder.config(option_name, option_value)
spark = builder.getOrCreate()

# Display the context configuration, appending any explicitly-set option that the
# context configuration does not already list.
settings = spark.sparkContext.getConf().getAll()
reported_keys = {entry[0] for entry in settings}
for item in explicit_options:
    if item not in reported_keys:
        settings.append((item, spark.conf.get(item)))
settings

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/11/12 17:54:32 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


[('spark.driver.extraJavaOptions',
  '-Djava.net.preferIPv6Addresses=false -XX:+IgnoreUnrecognizedVMOptions --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/jdk.internal.ref=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED --add-opens=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED -Djdk.reflect.useDirectMethodHandle=false'),
 ('spark.driver.host', 'x3000c0s25b0n0.head.cm.kestrel.hpc.nrel.gov'),
 ('spark.app.name', 'dsgrid'),
 

## Partitioned File Utilities

In [2]:
def is_partitioned(filepath):
    """Report whether ``filepath`` directly contains a ``name=value`` directory.

    Args:
        filepath: ``pathlib.Path``-like object to inspect.

    Returns:
        True if at least one immediate child of ``filepath`` is a directory whose
        name contains exactly one "="; False otherwise, including when
        ``filepath`` is not a directory (e.g. a single-file table).
    """
    # A plain file (or a missing path) has no children to inspect.
    if not filepath.is_dir():
        return False
    # Use .name rather than .stem: .stem drops everything after the last dot, which
    # hides the "=" in names such as "geo.county=01001".
    return any(
        child.is_dir() and child.name.count("=") == 1
        for child in filepath.iterdir()
    )

def get_partitions(filepath):
    """Yield the ``name=value`` partition directories directly under ``filepath``.

    Args:
        filepath: ``pathlib.Path``-like directory expected to hold partition
            sub-directories.

    Yields:
        ``(partition_name, value, path)`` tuples, one per matching sub-directory,
        in sorted path order.

    Raises:
        AssertionError: if ``filepath`` is not partitioned, or if the
            sub-directories use more than one partition name. Because this is a
            generator, the checks run when iteration starts, not at call time.
    """
    assert is_partitioned(filepath), f"{filepath} is not partitioned"

    partition_name = None
    # iterdir() order is arbitrary; sort so the yielded order is reproducible.
    for p in sorted(filepath.iterdir()):
        # Same criterion as is_partitioned(): a directory named with exactly one "=".
        # Names with more than one "=" are skipped rather than failing on unpacking.
        if not (p.is_dir() and p.name.count("=") == 1):
            continue
        # .name (not .stem) keeps values that contain a dot, e.g. "x=1.5".
        current_name, value = p.name.split("=")
        if partition_name:
            assert (current_name == partition_name), f"Found two different partition names in {filepath}: {partition_name}, {current_name}"
        partition_name = current_name
        yield partition_name, value, p

def print_partitions(filepath, print_depth=2, _depth=0):
    """Print the partition directories under ``filepath`` as an indented outline.

    Args:
        filepath: directory to inspect; nothing is printed if it is not partitioned.
        print_depth: number of partition levels to print. A falsy value (``None``
            or ``0``) removes the limit.
        _depth: current nesting level, used for indentation (4 spaces per level);
            set by the recursive calls.
    """
    if not is_partitioned(filepath):
        return

    indent = " " * (4 * _depth)
    for name, val, subdir in get_partitions(filepath):
        print(f"{indent}{name}={val}")

    # Only the LAST directory produced by the loop above is descended into, so each
    # deeper level shows one branch rather than the full tree.
    # NOTE(review): presumably intentional, to keep the listing short -- confirm.
    keep_going = (not print_depth) or ((_depth + 1) < print_depth)
    if keep_going:
        print_partitions(subdir, print_depth=print_depth, _depth=_depth + 1)

## Load Data

In [3]:
def get_metadata(dataset_path):
    """Load and return the parsed contents of ``metadata.json`` in ``dataset_path``.

    Args:
        dataset_path: ``pathlib.Path``-like directory containing ``metadata.json``.

    Returns:
        The decoded JSON document.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    # Read as UTF-8 explicitly so parsing does not depend on the platform's default
    # text encoding.
    with open(dataset_path / "metadata.json", encoding="utf-8") as f:
        return json.load(f)

# Read the dataset metadata, check the table layout, and record which column holds
# the values and which column represents each dimension type.
metadata = get_metadata(data_dir / dataset_name)
table_format = metadata["table_format"]
assert table_format["format_type"] == "unpivoted", table_format
value_column = table_format["value_column"]

# For every dimension type with at least one entry, keep the first column name of
# its first entry; dimension types with no entries are left out.
columns_by_type = {}
for dim_type, dim_entries in metadata["dimensions"].items():
    if dim_entries:
        columns_by_type[dim_type] = dim_entries[0]["column_names"][0]

In [4]:
# Load the data table and register it as a temporary view for the SQL cells below.
filepath = data_dir / dataset_name / "table.parquet"
df = spark.read.parquet(str(filepath))
tablename = "tbl"
df.createOrReplaceTempView(tablename)
logger.info("Loaded %s as %s", filepath, tablename)
# printSchema() writes the schema to stdout and returns None, so it is called on its
# own instead of being interpolated into the log message (which then logged "None").
df.printSchema()
df.show(n=5)

INFO:__main__:Loaded /datasets/dsgrid/dsgrid-tempo-v2022/full_dataset/table.parquet as tbl:
None


root
 |-- time_est: timestamp (nullable = true)
 |-- end_use: string (nullable = true)
 |-- household_and_vehicle_type: string (nullable = true)
 |-- transportation: string (nullable = true)
 |-- weather_2012: string (nullable = true)
 |-- value: double (nullable = true)
 |-- county: string (nullable = true)
 |-- scenario: string (nullable = true)
 |-- tempo_project_model_years: string (nullable = true)
 |-- state: string (nullable = true)



                                                                                

+-------------------+-------------------+--------------------------+--------------+------------+-----+------+---------+-------------------------+-----+
|           time_est|            end_use|household_and_vehicle_type|transportation|weather_2012|value|county| scenario|tempo_project_model_years|state|
+-------------------+-------------------+--------------------------+--------------+------------+-----+------+---------+-------------------------+-----+
|2012-11-22 00:00:00|electricity_ev_l1l2|      Some_Drivers_Larg...|         trans|        2012|  0.0| 24001|reference|                     2040|   MD|
|2012-07-27 20:00:00|electricity_ev_l1l2|      Some_Drivers_Smal...|         trans|        2012|  0.0| 24001|reference|                     2040|   MD|
|2012-11-22 00:00:00|electricity_ev_dcfc|      Some_Drivers_Larg...|         trans|        2012|  0.0| 24001|reference|                     2040|   MD|
|2012-07-27 20:00:00|electricity_ev_dcfc|      Some_Drivers_Smal...|         trans|     

In [5]:
# Row count of the loaded table, printed with thousands separators.
n_data_points = df.count()
print(f"Dataset contains {n_data_points:,} data points")



Dataset contains 264,288,771,856 data points


                                                                                

## Recreate Lefthand Side of Figure ES-1

We recommend using a Spark cluster for this query. <span style="color:red">
Local mode is expected to fail</span> for all but the smallest datasets. 

In [6]:
# Sum the value column per scenario and model year, scaled by 1e6, and bring the
# small aggregated result back as a pandas DataFrame.
# NOTE(review): the annual_twh name implies the raw values are in MWh -- confirm.
df = spark.sql(f"""SELECT scenario, 
                          {columns_by_type["model_year"]} as year, 
                          SUM({value_column})/1.0E6 as annual_twh
                     FROM {tablename} 
                 GROUP BY scenario, {columns_by_type["model_year"]}
                 ORDER BY scenario, year""").toPandas()

# Display names for the scenarios shown in the figure.
scenario_labels = {
    "efs_high_ldv": "EFS High Electrification",
    "ldv_sales_evs_2035": "All LDV Sales EV by 2035",
    "reference": "AEO Reference"
}
# map() turns any scenario missing from the dict into NaN; fall back to the raw
# scenario name so such rows keep a usable label.
df["scenario"] = df["scenario"].map(scenario_labels).fillna(df["scenario"])

24/11/12 17:56:56 ERROR Executor: Exception in task 48.0 in stage 47.0 (TID 10542)
java.lang.OutOfMemoryError: Java heap space
24/11/12 17:56:56 ERROR Executor: Exception in task 46.0 in stage 47.0 (TID 10540)
java.lang.OutOfMemoryError: Java heap space
24/11/12 17:56:56 ERROR Executor: Exception in task 92.0 in stage 47.0 (TID 10586)
java.lang.OutOfMemoryError: Java heap space
24/11/12 17:56:56 ERROR Executor: Exception in task 14.0 in stage 47.0 (TID 10508)
java.lang.OutOfMemoryError: GC overhead limit exceeded
24/11/12 17:56:56 ERROR Executor: Exception in task 15.0 in stage 47.0 (TID 10509)
java.lang.OutOfMemoryError: Java heap space
24/11/12 17:56:56 ERROR Executor: Exception in task 54.0 in stage 47.0 (TID 10548)
java.lang.OutOfMemoryError: Java heap space
24/11/12 17:56:56 ERROR Executor: Exception in task 91.0 in stage 47.0 (TID 10585)
java.lang.OutOfMemoryError: Java heap space
24/11/12 17:56:56 ERROR Executor: Exception in task 41.0 in stage 47.0 (TID 10535)
java.lang.OutOfMe

In [None]:
import plotly.express as px

# Annual totals by model year, one line per scenario.
axis_labels = {"annual_twh": "EV Load (TWh/yr)", "scenario": "Scenario"}
fig = px.line(
    df,
    x="year",
    y="annual_twh",
    color="scenario",
    labels=axis_labels,
    range_y=[-25, 1025],
    width=600,
    height=450,
    template="plotly_white",
)
fig

## Verify Timestamps Are As Expected

Timestamps display as expected because the `spark.sql.session.timeZone` configuration is set to "EST" in the first cell.

In [None]:
# The queries in this section hard-code the time_est column. Look the key up with
# .get() so a dataset with no "time" entry fails with the message below instead of a
# bare KeyError.
assert columns_by_type.get("time") == "time_est", "Code in this section only makes sense if the dataset has timestamps"

In [None]:
# Select a small slice of the data (one scenario, model year, geography, and where
# applicable one subsector and metric) and show its earliest timestamps.

where_clause = f"(scenario = 'reference') AND ({columns_by_type['model_year']} = 2050)"

# One example geography for each supported geography column.
geography_column = columns_by_type['geography']
example_geography = {
    "census_division": "middle_atlantic",
    "state": "RI",
    "county": "39023",
}
if geography_column not in example_geography:
    raise NotImplementedError()
where_clause += f" AND ({geography_column} = '{example_geography[geography_column]}')"

# The subsector dimension is optional; filter on it only when the dataset has one.
if "subsector" in columns_by_type:
    subsector_column = columns_by_type['subsector']
    example_subsector = {
        "subsector": "bev_compact",
        "household_and_vehicle_type": "Some_Drivers_Larger+Low_Income+Second_City+Pickup+BEV_100",
    }
    if subsector_column not in example_subsector:
        raise NotImplementedError()
    where_clause += f" AND ({subsector_column} = '{example_subsector[subsector_column]}')"

# No metric filter is added for "end_uses_by_fuel_type"; "end_use" is narrowed to a
# single end use.
metric_column = columns_by_type['metric']
if metric_column == "end_use":
    where_clause += f" AND ({metric_column} = 'electricity_ev_l1l2')"
elif metric_column != "end_uses_by_fuel_type":
    raise NotImplementedError()

df = spark.sql(f"SELECT * FROM {tablename} WHERE {where_clause} ORDER BY time_est LIMIT 5")
df.show()

## Verify that Profiles in Different Timezones Are As Expected

In [None]:
# The profile queries below hard-code the time_est column. Look the key up with
# .get() so a dataset with no "time" entry fails with the message below instead of a
# bare KeyError.
assert columns_by_type.get("time") == "time_est", "Code in this section only makes sense if the dataset has timestamps"

In [None]:
def get_profile(start_timestamp, end_timestamp, where_clause, 
                tablename=tablename, value_column=value_column, 
                normalize_profile=True, replace_timestamps=True):
    """Return the summed value per timestamp for one slice of the table.

    Args:
        start_timestamp: first timestamp to include (inclusive); interpolated into
            a SQL ``TIMESTAMP '...'`` literal.
        end_timestamp: last timestamp to include (inclusive); interpolated the same way.
        where_clause: SQL predicate selecting the rows to aggregate.
        tablename: view to query; defaults to the notebook-level ``tablename`` as it
            was when this function was defined.
        value_column: column to sum; defaults to the notebook-level ``value_column``
            as it was when this function was defined.
        normalize_profile: if True, divide each value by the sum over the returned rows.
        replace_timestamps: if True, return an ``hour`` column holding each row's
            position (0, 1, 2, ...) in the time-ordered result instead of ``time_est``.

    Returns:
        pandas DataFrame with columns ``["hour", value_column]`` when
        ``replace_timestamps`` is True, otherwise ``["time_est", value_column]``.
    """
    query = f"""SELECT time_est, 
                              SUM({value_column}) as {value_column}
                         FROM {tablename} 
                        WHERE {where_clause} AND 
                              (time_est >= TIMESTAMP '{start_timestamp}') AND 
                              (time_est <= TIMESTAMP '{end_timestamp}')
                     GROUP BY time_est 
                     ORDER BY time_est"""
    profile = spark.sql(query).toPandas()

    if normalize_profile:
        total = profile[value_column].sum()
        profile[value_column] = profile[value_column] / total

    if not replace_timestamps:
        return profile

    # "hour" is the row position, which equals the hour offset from the first
    # timestamp only if no timestamps are missing from the result.
    profile["hour"] = profile.index.values
    return profile[["hour", value_column]]

# Compare the daily profile of one example geography per U.S. time zone, on one day
# in February and one in August.
where_clause = f"(scenario = 'reference') AND ({columns_by_type['model_year']} = 2050)"

# The subsector dimension is optional; filter on it only when the dataset has one.
if "subsector" in columns_by_type:
    subsector_column = columns_by_type['subsector']
    example_subsector = {
        "subsector": "bev_compact",
        "household_and_vehicle_type": "Some_Drivers_Smaller+Middle_Income+Suburban+SUV+BEV_300",
    }
    if subsector_column not in example_subsector:
        raise NotImplementedError()
    where_clause += f" AND ({subsector_column} = '{example_subsector[subsector_column]}')"

# Example geography per time zone label, for each supported geography column.
geographies_by_column = {
    "census_division": {"ET": "middle_atlantic", "CT": "west_south_central", "MT": "mountain", "PT": "pacific"},
    "state": {"ET": "NC", "CT": "TX", "MT": "CO", "PT": "OR"},
    "county": {"ET": "37183", "CT": "48453", "MT": "08069", "PT": "06059"},
}
geography_column = columns_by_type['geography']
geographies = geographies_by_column.get(geography_column)
if geographies is None:
    raise NotImplementedError()

# First and last hour of each example day.
days = {
    "Standard Time": (dt.datetime(2012, 2, 14, 0), dt.datetime(2012, 2, 14, 23)),
    "Daylight Savings Time": (dt.datetime(2012, 8, 14, 0), dt.datetime(2012, 8, 14, 23))
}

# One profile per (day, time zone), each tagged with the labels used for plotting.
data = []
for time_type, (day_start, day_end) in days.items():
    for tz, geo in geographies.items():
        profile = get_profile(day_start, day_end, f"{where_clause} AND {geography_column} = '{geo}'")
        data.append(profile.assign(**{"Time Type": time_type, "Time Zone": tz}))
df = pd.concat(data)
df

In [None]:
import plotly.express as px

# One line per time zone, dashed by time type. The axis label is keyed on
# value_column (not the literal "value") so it is applied whatever the dataset's
# value column is called.
fig = px.line(df, x="hour", y=value_column, color="Time Zone", line_dash="Time Type",
              color_discrete_map={"ET": "red", "CT": "orange", "MT": "blue", "PT": "purple"},
              labels={value_column: "Normalized Load Profile", "hour": "Hour of EST Day"},
              width=600, template="plotly_white")
fig

## Demonstrate Loading a Subset of a Larger Dataset

In [None]:
# Show the partition layout of the table directory (no depth limit), then its path.
filepath = data_dir.joinpath(dataset_name, "table.parquet")
print_partitions(filepath, print_depth=None)
print(filepath)

In [None]:
# Edit this list of tuples as desired
partitions=[
    ("scenario", "ldv_sales_evs_2035"),
    ("tempo_project_model_years", "2040"),
    ("state", "VT")
]

# Build the path of the selected partition directory, one name=value level at a time.
subset_filepath = filepath
for partition_name, value in partitions:
    subset_filepath = subset_filepath / f"{partition_name}={value}"

# Load partial data table.
# Pointing the reader at a partition sub-directory makes Spark treat that directory
# as the table root, which drops the partition columns listed above from the result.
# Setting basePath to the table root keeps them, so the subset has the same columns
# as the full table.
df = spark.read.option("basePath", str(filepath)).parquet(str(subset_filepath))
# This reuses the view name "tbl", so the view now refers to the subset rather than
# the full table.
tablename = "tbl"
df.createOrReplaceTempView(tablename)
df.show(n=5)

In [None]:
# Row count of the partial table, printed with thousands separators.
n_partial_points = df.count()
print(f"Partial dataset contains {n_partial_points:,} data points")