In [5]:
import datetime as dt
import logging
import json
from pathlib import Path

from pyspark.sql import SparkSession
import pandas as pd

# Configure the root logger at INFO level and create this notebook's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Root directory of the dataset and the dataset variant to analyze.
data_dir = Path("/datasets/dsgrid-tempo-v2022/")
# Alternative variant: dataset_name = "full_dataset"
dataset_name = "simple_profiles"

# SQL options requested for this session: partition-column type inference is
# switched off and the session time zone is set to "EST".
session_options = {
    "spark.sql.sources.partitionColumnTypeInference.enabled": "false",
    "spark.sql.session.timeZone": "EST",
}

builder = SparkSession.builder.appName("dsgrid")
for option_key, option_value in session_options.items():
    builder = builder.config(option_key, option_value)
spark = builder.getOrCreate()

# List the context-level configuration, then add the two SQL options above
# (read back from the live session) when the context listing lacks them.
settings = spark.sparkContext.getConf().getAll()
listed_keys = {key for key, _ in settings}
for option_key in session_options:
    if option_key not in listed_keys:
        settings.append((option_key, spark.conf.get(option_key)))
        listed_keys.add(option_key)
settings

[('spark.sql.shuffle.partitions', '103'),
 ('spark.executor.cores', '5'),
 ('spark.executor.id', 'driver'),
 ('spark.driver.maxResultSize', '1g'),
 ('spark.driver.memory', '1g'),
 ('spark.app.name', 'PySparkShell'),
 ('spark.app.id', 'app-20240524205619-0000'),
 ('spark.executor.memory', '12g'),
 ('spark.driver.extraJavaOptions',
  '-Djava.net.preferIPv6Addresses=false -XX:+IgnoreUnrecognizedVMOptions --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED --add-opens=jav

## Partitioned File Utilities

In [6]:
def is_partitioned(filepath):
    """Return True if ``filepath`` contains at least one ``name=value`` directory.

    A child counts as a partition directory when it is a directory whose name
    contains exactly one ``=``.

    Parameters
    ----------
    filepath : pathlib.Path
        Directory to inspect. ``iterdir()`` raises if it does not exist or is
        not a directory.

    Returns
    -------
    bool
    """
    # Test the full directory name rather than Path.stem: stem discards
    # everything after the last ".", so "geo.region=west" would be missed.
    return any(
        child.is_dir() and child.name.count("=") == 1
        for child in filepath.iterdir()
    )

def get_partitions(filepath):
    """Yield ``(partition_name, value, path)`` for each partition directory.

    A partition directory is a child directory of ``filepath`` whose name
    contains exactly one ``=``; other children are skipped. All partition
    directories at this level must share one partition name.

    Parameters
    ----------
    filepath : pathlib.Path
        Directory that directly contains the ``name=value`` directories.

    Yields
    ------
    tuple of (str, str, pathlib.Path)
        Partition name, partition value, and the directory's path, in the
        order returned by ``iterdir()``.

    Raises
    ------
    AssertionError
        If ``filepath`` is not partitioned, or if two different partition
        names are found. Because this is a generator, the checks run on
        iteration, not at call time.
    """
    assert is_partitioned(filepath), f"{filepath} is not partitioned"

    partition_name = None
    for child in filepath.iterdir():
        # Use the full directory name: Path.stem would cut a value such as
        # "1.5" down to "1". Requiring exactly one "=" keeps the unpacking
        # below safe for names like "a=b=c".
        if not child.is_dir() or child.name.count("=") != 1:
            continue
        name, value = child.name.split("=")
        if partition_name:
            assert (name == partition_name), f"Found two different partition names in {filepath}: {partition_name}, {name}"
        partition_name = name
        yield partition_name, value, child

def print_partitions(filepath, print_depth=2, _depth=0):
    """Print the ``name=value`` partition directories under ``filepath``.

    Every partition at the current level is printed, indented four spaces per
    level. The function then recurses into only the last directory yielded by
    ``get_partitions``, so one branch is shown per level.
    NOTE(review): descending into a single branch looks deliberate, to keep
    the output short. Confirm; which directory comes last depends on the
    filesystem's listing order.

    Parameters
    ----------
    filepath : pathlib.Path
        Directory to inspect; nothing is printed if it is not partitioned.
    print_depth : int or None
        Number of levels to print. Any falsy value (``None`` or ``0``) means
        no limit.
    _depth : int
        Current recursion level; callers should leave this at its default.
    """
    if not is_partitioned(filepath):
        return

    indent = " " * (4 * _depth)
    for partition_name, value, p in get_partitions(filepath):
        print(f"{indent}{partition_name}={value}")

    # `p` still refers to the last directory seen by the loop above.
    no_limit = not print_depth
    if no_limit or (_depth + 1) < print_depth:
        print_partitions(p, print_depth=print_depth, _depth=_depth + 1)

## Load Data

In [7]:
def get_metadata(dataset_path):
    """Load and return the parsed contents of ``metadata.json``.

    Parameters
    ----------
    dataset_path : pathlib.Path or str
        Directory that contains ``metadata.json``.

    Returns
    -------
    The decoded JSON document.

    Raises
    ------
    FileNotFoundError
        If ``metadata.json`` is missing from ``dataset_path``.
    json.JSONDecodeError
        If the file is not valid JSON.
    """
    # Path() lets callers pass a plain string. The encoding is explicit so the
    # file is decoded the same way whatever the platform's locale default is.
    metadata_file = Path(dataset_path) / "metadata.json"
    with open(metadata_file, encoding="utf-8") as f:
        return json.load(f)

# Read the dataset metadata; the rest of the notebook assumes an unpivoted table.
metadata = get_metadata(data_dir / dataset_name)

table_format = metadata["table_format"]
assert table_format["format_type"] == "unpivoted", table_format
value_column = table_format["value_column"]

# Map each dimension type to the first column name of its first entry;
# dimension types with no entries are left out.
columns_by_type = {}
for dim_type, dim_entries in metadata["dimensions"].items():
    if dim_entries:
        columns_by_type[dim_type] = dim_entries[0]["column_names"][0]

In [8]:
# Load the data table and register it as a temporary view for SQL queries.
filepath = data_dir / dataset_name / "table.parquet"
df = spark.read.parquet(str(filepath))
tablename = "tbl"
df.createOrReplaceTempView(tablename)
# printSchema() writes the schema to stdout and returns None, so embedding it
# in the log message only appended the text "None". Log first, then print.
logger.info("Loaded %s as %s", filepath, tablename)
df.printSchema()
df.show(n=5)

24/05/24 14:58:49 WARN TaskSetManager: Lost task 0.0 in stage 1.0 (TID 4) (10.150.14.89 executor 2): org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.SparkThreadUtils$.awaitResult(SparkThreadUtils.scala:56)
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:310)
	at org.apache.spark.util.ThreadUtils$.parmap(ThreadUtils.scala:383)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$.readParquetFootersInParallel(ParquetFileFormat.scala:443)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$.$anonfun$mergeSchemasInParallel$1(ParquetFileFormat.scala:493)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$.$anonfun$mergeSchemasInParallel$1$adapted(ParquetFileFormat.scala:485)
	at org.apache.spark.sql.execution.datasources.SchemaMergeUtils$.$anonfun$mergeSchemasInParallel$2(SchemaMergeUtils.scala:79)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:85

Py4JJavaError: An error occurred while calling o191.parquet.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 1.0 failed 4 times, most recent failure: Lost task 0.3 in stage 1.0 (TID 7) (10.150.14.89 executor 10): org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.SparkThreadUtils$.awaitResult(SparkThreadUtils.scala:56)
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:310)
	at org.apache.spark.util.ThreadUtils$.parmap(ThreadUtils.scala:383)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$.readParquetFootersInParallel(ParquetFileFormat.scala:443)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$.$anonfun$mergeSchemasInParallel$1(ParquetFileFormat.scala:493)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$.$anonfun$mergeSchemasInParallel$1$adapted(ParquetFileFormat.scala:485)
	at org.apache.spark.sql.execution.datasources.SchemaMergeUtils$.$anonfun$mergeSchemasInParallel$2(SchemaMergeUtils.scala:79)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:855)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:855)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:161)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
	at java.base/java.lang.Thread.run(Thread.java:833)
Caused by: java.io.FileNotFoundException: File file:/datasets/dsgrid-tempo-v2022/simple_profiles/table.parquet/part-00000-118a325e-25dd-4ef7-b470-495add8be680-c000.snappy.parquet does not exist
	at org.apache.hadoop.fs.RawLocalFileSystem.deprecatedGetFileStatus(RawLocalFileSystem.java:779)
	at org.apache.hadoop.fs.RawLocalFileSystem.getFileLinkStatusInternal(RawLocalFileSystem.java:1100)
	at org.apache.hadoop.fs.RawLocalFileSystem.getFileStatus(RawLocalFileSystem.java:769)
	at org.apache.hadoop.fs.FilterFileSystem.getFileStatus(FilterFileSystem.java:462)
	at org.apache.hadoop.fs.ChecksumFileSystem$ChecksumFSInputChecker.<init>(ChecksumFileSystem.java:160)
	at org.apache.hadoop.fs.ChecksumFileSystem.open(ChecksumFileSystem.java:372)
	at org.apache.hadoop.fs.FileSystem.open(FileSystem.java:976)
	at org.apache.parquet.hadoop.util.HadoopInputFile.newStream(HadoopInputFile.java:69)
	at org.apache.parquet.hadoop.ParquetFileReader.<init>(ParquetFileReader.java:796)
	at org.apache.parquet.hadoop.ParquetFileReader.open(ParquetFileReader.java:666)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFooterReader.readFooter(ParquetFooterReader.java:85)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFooterReader.readFooter(ParquetFooterReader.java:76)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$.$anonfun$readParquetFootersInParallel$1(ParquetFileFormat.scala:450)
	at org.apache.spark.util.ThreadUtils$.$anonfun$parmap$2(ThreadUtils.scala:380)
	at scala.concurrent.Future$.$anonfun$apply$1(Future.scala:659)
	at scala.util.Success.$anonfun$map$1(Try.scala:255)
	at scala.util.Success.map(Try.scala:213)
	at scala.concurrent.Future.$anonfun$map$1(Future.scala:292)
	at scala.concurrent.impl.Promise.liftedTree1$1(Promise.scala:33)
	at scala.concurrent.impl.Promise.$anonfun$transform$1(Promise.scala:33)
	at scala.concurrent.impl.CallbackRunnable.run(Promise.scala:64)
	at java.base/java.util.concurrent.ForkJoinTask$RunnableExecuteAction.exec(ForkJoinTask.java:1395)
	at java.base/java.util.concurrent.ForkJoinTask.doExec(ForkJoinTask.java:373)
	at java.base/java.util.concurrent.ForkJoinPool$WorkQueue.topLevelExec(ForkJoinPool.java:1182)
	at java.base/java.util.concurrent.ForkJoinPool.scan(ForkJoinPool.java:1655)
	at java.base/java.util.concurrent.ForkJoinPool.runWorker(ForkJoinPool.java:1622)
	at java.base/java.util.concurrent.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:165)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2844)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2780)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2779)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2779)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1242)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1242)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1242)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3048)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2982)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2971)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:984)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2398)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2419)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2438)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2463)
	at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1046)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:407)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:1045)
	at org.apache.spark.sql.execution.datasources.SchemaMergeUtils$.mergeSchemasInParallel(SchemaMergeUtils.scala:73)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$.mergeSchemasInParallel(ParquetFileFormat.scala:497)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetUtils$.inferSchema(ParquetUtils.scala:132)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat.inferSchema(ParquetFileFormat.scala:79)
	at org.apache.spark.sql.execution.datasources.DataSource.$anonfun$getOrInferFileFormatSchema$11(DataSource.scala:208)
	at scala.Option.orElse(Option.scala:447)
	at org.apache.spark.sql.execution.datasources.DataSource.getOrInferFileFormatSchema(DataSource.scala:205)
	at org.apache.spark.sql.execution.datasources.DataSource.resolveRelation(DataSource.scala:407)
	at org.apache.spark.sql.DataFrameReader.loadV1Source(DataFrameReader.scala:229)
	at org.apache.spark.sql.DataFrameReader.$anonfun$load$2(DataFrameReader.scala:211)
	at scala.Option.getOrElse(Option.scala:189)
	at org.apache.spark.sql.DataFrameReader.load(DataFrameReader.scala:211)
	at org.apache.spark.sql.DataFrameReader.parquet(DataFrameReader.scala:563)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.lang.Thread.run(Thread.java:750)
Caused by: org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.SparkThreadUtils$.awaitResult(SparkThreadUtils.scala:56)
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:310)
	at org.apache.spark.util.ThreadUtils$.parmap(ThreadUtils.scala:383)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$.readParquetFootersInParallel(ParquetFileFormat.scala:443)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$.$anonfun$mergeSchemasInParallel$1(ParquetFileFormat.scala:493)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$.$anonfun$mergeSchemasInParallel$1$adapted(ParquetFileFormat.scala:485)
	at org.apache.spark.sql.execution.datasources.SchemaMergeUtils$.$anonfun$mergeSchemasInParallel$2(SchemaMergeUtils.scala:79)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:855)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:855)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:364)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:328)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:161)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
	at java.lang.Thread.run(Thread.java:833)
Caused by: java.io.FileNotFoundException: File file:/datasets/dsgrid-tempo-v2022/simple_profiles/table.parquet/part-00000-118a325e-25dd-4ef7-b470-495add8be680-c000.snappy.parquet does not exist
	at org.apache.hadoop.fs.RawLocalFileSystem.deprecatedGetFileStatus(RawLocalFileSystem.java:779)
	at org.apache.hadoop.fs.RawLocalFileSystem.getFileLinkStatusInternal(RawLocalFileSystem.java:1100)
	at org.apache.hadoop.fs.RawLocalFileSystem.getFileStatus(RawLocalFileSystem.java:769)
	at org.apache.hadoop.fs.FilterFileSystem.getFileStatus(FilterFileSystem.java:462)
	at org.apache.hadoop.fs.ChecksumFileSystem$ChecksumFSInputChecker.<init>(ChecksumFileSystem.java:160)
	at org.apache.hadoop.fs.ChecksumFileSystem.open(ChecksumFileSystem.java:372)
	at org.apache.hadoop.fs.FileSystem.open(FileSystem.java:976)
	at org.apache.parquet.hadoop.util.HadoopInputFile.newStream(HadoopInputFile.java:69)
	at org.apache.parquet.hadoop.ParquetFileReader.<init>(ParquetFileReader.java:796)
	at org.apache.parquet.hadoop.ParquetFileReader.open(ParquetFileReader.java:666)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFooterReader.readFooter(ParquetFooterReader.java:85)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFooterReader.readFooter(ParquetFooterReader.java:76)
	at org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat$.$anonfun$readParquetFootersInParallel$1(ParquetFileFormat.scala:450)
	at org.apache.spark.util.ThreadUtils$.$anonfun$parmap$2(ThreadUtils.scala:380)
	at scala.concurrent.Future$.$anonfun$apply$1(Future.scala:659)
	at scala.util.Success.$anonfun$map$1(Try.scala:255)
	at scala.util.Success.map(Try.scala:213)
	at scala.concurrent.Future.$anonfun$map$1(Future.scala:292)
	at scala.concurrent.impl.Promise.liftedTree1$1(Promise.scala:33)
	at scala.concurrent.impl.Promise.$anonfun$transform$1(Promise.scala:33)
	at scala.concurrent.impl.CallbackRunnable.run(Promise.scala:64)
	at java.util.concurrent.ForkJoinTask$RunnableExecuteAction.exec(ForkJoinTask.java:1395)
	at java.util.concurrent.ForkJoinTask.doExec(ForkJoinTask.java:373)
	at java.util.concurrent.ForkJoinPool$WorkQueue.topLevelExec(ForkJoinPool.java:1182)
	at java.util.concurrent.ForkJoinPool.scan(ForkJoinPool.java:1655)
	at java.util.concurrent.ForkJoinPool.runWorker(ForkJoinPool.java:1622)
	at java.util.concurrent.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:165)


In [None]:
# Count the rows of the loaded table and report the total with thousands separators.
n_data_points = df.count()
print(f"Dataset contains {n_data_points:,} data points")

## Recreate Left-Hand Side of Figure ES-1

We recommend using a Spark cluster for this query. Local mode is expected to fail for all but the smallest datasets. 

In [None]:
# Total the value column per scenario and model year, scaled by 1e6 and
# reported as annual_twh (presumably the values are in MWh; TODO confirm).
df = spark.sql(f"""SELECT scenario, 
                          {columns_by_type["model_year"]} as year, 
                          SUM({value_column})/1.0E6 as annual_twh
                     FROM {tablename} 
                 GROUP BY scenario, {columns_by_type["model_year"]}
                 ORDER BY scenario, year""").toPandas()

# Replace scenario ids with display names. Series.map alone turns any id that
# is not listed into NaN, so fall back to the raw id in that case.
scenario_labels = {
    "efs_high_ldv": "EFS High Electrification",
    "ldv_sales_evs_2035": "All LDV Sales EV by 2035",
    "reference": "AEO Reference"
}
df["scenario"] = df["scenario"].map(scenario_labels).fillna(df["scenario"])

In [None]:
import plotly.express as px

# Annual EV load by scenario: one line per scenario, model year on the x axis.
axis_labels = {"annual_twh": "EV Load (TWh/yr)", "scenario": "Scenario"}
fig = px.line(
    df,
    x="year",
    y="annual_twh",
    color="scenario",
    labels=axis_labels,
    range_y=[-25,1025],
    width=600,
    height=450,
    template="plotly_white",
)
fig

## Verify Timestamps Are As Expected

Timestamps display as expected because the first cell sets the `spark.sql.session.timeZone` configuration to "EST".

In [None]:
# Guard for this section. .get() is used so that a dataset without a time
# dimension fails with the message below instead of a bare KeyError.
assert columns_by_type.get("time") == "time_est", "Code in this section only makes sense if the dataset has timestamps"

In [None]:
# Select a subset of the data and look at the initial timestamps.
# The filter picks one scenario and model year, then one example member for
# each further dimension, keyed by the column name the dataset uses.

where_clause = f"(scenario = 'reference') AND ({columns_by_type['model_year']} = 2050)"

# Geography is required; an unrecognized column name is an error.
preview_geographies = {
    "census_division": "middle_atlantic",
    "state": "RI",
    "county": "39023",
}
geography_column = columns_by_type['geography']
if geography_column not in preview_geographies:
    raise NotImplementedError()
where_clause += f" AND ({geography_column} = '{preview_geographies[geography_column]}')"

# Subsector is optional; when present, its column name must be recognized.
preview_subsectors = {
    "subsector": "bev_compact",
    "household_and_vehicle_type": "Some_Drivers_Larger+Low_Income+Second_City+Pickup+BEV_100",
}
if "subsector" in columns_by_type:
    subsector_column = columns_by_type['subsector']
    if subsector_column not in preview_subsectors:
        raise NotImplementedError()
    where_clause += f" AND ({subsector_column} = '{preview_subsectors[subsector_column]}')"

# Metric: "end_uses_by_fuel_type" needs no extra filter, "end_use" does.
metric_column = columns_by_type['metric']
if metric_column == "end_use":
    where_clause += f" AND ({metric_column} = 'electricity_ev_l1l2')"
elif metric_column != "end_uses_by_fuel_type":
    raise NotImplementedError()

df = spark.sql(f"SELECT * FROM {tablename} WHERE {where_clause} ORDER BY time_est LIMIT 5")
df.show()

## Verify that Profiles in Different Timezones Are As Expected

In [None]:
# Guard for this section. .get() is used so that a dataset without a time
# dimension fails with the message below instead of a bare KeyError.
assert columns_by_type.get("time") == "time_est", "Code in this section only makes sense if the dataset has timestamps"

In [None]:
def get_profile(start_timestamp, end_timestamp, where_clause, 
                tablename=tablename, value_column=value_column, 
                normalize_profile=True, replace_timestamps=True):
    """Return the summed value per ``time_est`` timestamp as a pandas DataFrame.

    Parameters
    ----------
    start_timestamp, end_timestamp
        Inclusive bounds on ``time_est``. Each is interpolated into a SQL
        ``TIMESTAMP '...'`` literal using its ``str()`` form.
    where_clause : str
        Additional SQL predicate, ANDed with the time bounds.
    tablename : str
        View or table to query. The default is the value the module-level
        ``tablename`` had when this function was defined.
    value_column : str
        Column to sum. The default is likewise fixed at definition time.
    normalize_profile : bool
        If True, divide every value by the total over the returned rows.
    replace_timestamps : bool
        If True, return an ``hour`` column holding each row's position in the
        time-ordered result instead of the ``time_est`` column.

    Returns
    -------
    pandas.DataFrame
        Columns ``["hour", value_column]`` when ``replace_timestamps`` is
        True, otherwise ``["time_est", value_column]``.
    """
    query = f"""SELECT time_est, 
                              SUM({value_column}) as {value_column}
                         FROM {tablename} 
                        WHERE {where_clause} AND 
                              (time_est >= TIMESTAMP '{start_timestamp}') AND 
                              (time_est <= TIMESTAMP '{end_timestamp}')
                     GROUP BY time_est 
                     ORDER BY time_est"""
    profile = spark.sql(query).toPandas()

    if normalize_profile:
        total = profile[value_column].sum()
        profile[value_column] = profile[value_column] / total

    if not replace_timestamps:
        return profile

    # The rows are ordered by time, so the row index stands in for the hour.
    profile["hour"] = profile.index.values
    return profile[["hour", value_column]]

# Build one normalized daily profile per (time type, time zone) pair so that
# their shapes can be compared on a common EST hour axis.
where_clause = f"(scenario = 'reference') AND ({columns_by_type['model_year']} = 2050)"

# Subsector is optional; when present, its column name must be recognized.
profile_subsectors = {
    "subsector": "bev_compact",
    "household_and_vehicle_type": "Some_Drivers_Smaller+Middle_Income+Suburban+SUV+BEV_300",
}
if "subsector" in columns_by_type:
    subsector_column = columns_by_type['subsector']
    if subsector_column not in profile_subsectors:
        raise NotImplementedError()
    where_clause += f" AND ({subsector_column} = '{profile_subsectors[subsector_column]}')"

# One example geography per time zone label, for each supported geography column.
geographies_by_column = {
    "census_division": {"ET": "middle_atlantic", "CT": "west_south_central", "MT": "mountain", "PT": "pacific"},
    "state": {"ET": "NC", "CT": "TX", "MT": "CO", "PT": "OR"},
    "county": {"ET": "37183", "CT": "48453", "MT": "08069", "PT": "06059"},
}
geography_column = columns_by_type['geography']
geographies = geographies_by_column.get(geography_column)
if geographies is None:
    raise NotImplementedError()

# One sample day for each time type, given as (first hour, last hour).
days = {
    "Standard Time": (dt.datetime(2012, 2, 14, 0), dt.datetime(2012, 2, 14, 23)),
    "Daylight Savings Time": (dt.datetime(2012, 8, 14, 0), dt.datetime(2012, 8, 14, 23))
}

data = []
for time_type, (day_start, day_end) in days.items():
    for tz, geo in geographies.items():
        profile = get_profile(day_start, day_end, where_clause + f" AND {geography_column} = '{geo}'")
        profile["Time Type"] = time_type
        profile["Time Zone"] = tz
        data.append(profile)
df = pd.concat(data)
df

In [None]:
import plotly.express as px

# Daily profiles by time zone (color) and time type (dash), on an EST hour axis.
fig = px.line(df, x="hour", y=value_column, color="Time Zone", line_dash="Time Type",
              color_discrete_map={"ET": "red", "CT": "orange", "MT": "blue", "PT": "purple"},
              # Key the y-axis label on the plotted column: a hard-coded
              # "value" key only applies when value_column is named "value".
              labels={value_column: "Normalized Load Profile", "hour": "Hour of EST Day"},
              #range_y=[0,0.1],
              width=600, template="plotly_white")
fig

## Demonstrate Loading a Subset of a Larger Dataset

In [None]:
# Print the partition layout of the table with no depth limit.
filepath = data_dir.joinpath(dataset_name, "table.parquet")
print_partitions(filepath, print_depth=None)

In [None]:
# Edit this list of (partition column, value) tuples as desired. The order
# must match the nesting order of the partition directories on disk.
partitions = [
    ("scenario", "ldv_sales_evs_2035"),
    ("tempo_project_model_years", "2040"),
    ("state", "VT"),
]

# Turn each tuple into its "name=value" directory and append them to the table path.
partition_dirs = [f"{partition_name}={value}" for partition_name, value in partitions]
subset_filepath = filepath.joinpath(*partition_dirs)

# Load only the selected partition and register it under the same view name.
# NOTE(review): reading a partition directory directly presumably leaves the
# partition columns out of the resulting DataFrame. Confirm against Spark's
# partition discovery documentation.
df = spark.read.parquet(str(subset_filepath))
tablename = "tbl"
df.createOrReplaceTempView(tablename)
df.show(n=5)

In [None]:
# Count the rows of the partial table and report the total with thousands separators.
n_subset_points = df.count()
print(f"Partial dataset contains {n_subset_points:,} data points")