# Exploratory data analysis

In [1]:
import os
import xarray as xr

import matplotlib.pyplot as plt

In [9]:
# Switch the working directory to the demand_model project.
# NOTE(review): presumably needed so that the local `functions` module below
# can be imported and so "./figures/" resolves inside the project — confirm.
%cd /g/data/w42/dr6273/work/demand_model/

# Project-local helper module, referred to as `fn` in later cells.
import functions as fn

/g/data/w42/dr6273/work/demand_model


In [3]:
# Root directory of the Aus_energy project; later cells build data file paths from it.
path = "/g/data/w42/dr6273/work/projects/Aus_energy/"

# Load predictors and predictands

Detrended energy demand (predictand)

In [5]:
# Read the "demand_stl" variable from the daily demand file.
# os.path.join is used instead of string concatenation because `path` already
# ends with a separator, so `path + "/data/..."` produced a doubled "/".
dem_da = xr.open_dataset(
    os.path.join(path, "data", "energy_demand", "daily_demand_2010-2020_stl.nc")
)["demand_stl"]

Predictors. Load everything from the relevant directory.

In [10]:
# Ask the project helper for the list of predictor files, then display it.
# NOTE(review): "NEM" and "pop_dens_mask" look like the region set and mask
# name that appear in the returned file names — confirm against
# functions.get_predictor_files.
files = fn.get_predictor_files("NEM", "pop_dens_mask")
files

['/g/data/w42/dr6273/work/projects/Aus_energy/demand_predictors/mtpr_era5_daily_1959-2022_NEM_pop_dens_mask.nc',
 '/g/data/w42/dr6273/work/projects/Aus_energy/demand_predictors/10w_era5_daily_1959-2022_NEM_pop_dens_mask.nc',
 '/g/data/w42/dr6273/work/projects/Aus_energy/demand_predictors/msdwswrf_era5_daily_1959-2022_NEM_pop_dens_mask.nc',
 '/g/data/w42/dr6273/work/projects/Aus_energy/demand_predictors/rh_era5_daily_1959-2022_NEM_pop_dens_mask.nc',
 '/g/data/w42/dr6273/work/projects/Aus_energy/demand_predictors/cdd_24_rollmean3_era5_daily_1959-2022_NEM_pop_dens_mask.nc',
 '/g/data/w42/dr6273/work/projects/Aus_energy/demand_predictors/2t_rollmean4_era5_daily_1959-2022_NEM_pop_dens_mask.nc',
 '/g/data/w42/dr6273/work/projects/Aus_energy/demand_predictors/2tmin_era5_daily_1959-2022_NEM_pop_dens_mask.nc',
 '/g/data/w42/dr6273/work/projects/Aus_energy/demand_predictors/2tmax_era5_daily_1959-2022_NEM_pop_dens_mask.nc',
 '/g/data/w42/dr6273/work/projects/Aus_energy/demand_predictors/hdd_18_ro

In [13]:
# Open every predictor file as a single dataset and load it with .compute().
# NOTE(review): compat="override" is understood to skip consistency checks
# between files when combining — confirm all files share the same coordinates.
pred_ds = xr.open_mfdataset(files, combine="nested", compat="override").compute()

In [14]:
# Show which predictor variables were loaded.
pred_ds.data_vars

Data variables:
    mtpr      (region, time) float64 2.842e-05 4.244e-05 ... 1.614e-05 6.045e-06
    w10       (region, time) float64 3.489 3.602 3.166 ... 1.758 3.059 3.369
    msdwswrf  (region, time) float64 266.6 293.4 257.6 ... 281.5 249.9 299.6
    rh        (region, time) float64 72.35 69.38 69.15 ... 65.67 76.73 71.14
    cdd3      (region, time) float64 nan nan 0.2343 ... 5.632e-05 5.633e-05
    t2m4      (region, time) float64 nan nan nan 293.5 ... 290.4 290.1 289.0
    t2min     (region, time) float64 nan 289.1 289.3 290.0 ... 281.9 287.8 287.9
    t2max     (region, time) float64 nan 297.4 297.7 300.2 ... 291.1 293.3 297.2
    hdd4      (region, time) float64 nan nan nan 0.5197 ... 2.288 2.441 2.455
    q         (region, time) float64 10.62 10.2 10.42 ... 6.549 9.237 9.997
    cdd4      (region, time) float64 nan nan nan ... 0.1403 0.1403 5.633e-05
    hdd       (region, time) float64 1.053 0.5182 0.3264 ... 4.067 1.35 0.06106
    cdd       (region, time) float64 0.2843 0.

In [15]:
# Display the dataset summary.
pred_ds

In [16]:
def ts(ax, da, region):
    """
    Draw the time series of one region onto an existing axes.

    ax: axes to draw on
    da: dataarray with a 'region' coordinate
    region: region name to select from da

    Returns None; the axes is modified in place.
    """
    line_style = {"alpha": 0.7, "lw": 0.5}

    regional = da.sel(region=region)
    regional.plot(ax=ax, **line_style)

    # Blank out whatever title the plot call put on the axes.
    ax.set_title("")

In [17]:
def scatter(ax, y_da, x_ds, x_var, region):
    """
    Scatter the predictand (y axis) against a single predictor (x axis).

    ax: axes to draw on
    y_da: dataarray of predictand; plotted divided by 1000 and labelled
        "Demand [GWh/day]"
    x_ds: dataset of predictors
    x_var: name of the data_var in x_ds to put on the x axis
    region: region name selected from both y_da and x_ds

    The temperature variables listed below have 273.15 subtracted before
    plotting (Kelvin to Celsius). The predictor is restricted to the time
    steps of y_da.
    """
    kelvin_vars = ("t2m", "t2min", "t2max", "t2m3", "t2m4")

    predictor = x_ds[x_var]
    if x_var in kelvin_vars:
        predictor = predictor - 273.15

    # Align the predictor to the predictand's time axis for the chosen region.
    x_vals = predictor.sel(region=region, time=y_da["time"])
    y_vals = y_da.sel(region=region) / 1000

    ax.scatter(x_vals, y_vals, s=1, alpha=0.3)
    ax.set_ylabel("Demand [GWh/day]")
    ax.set_xlabel(x_var)

In [18]:
# Plotting order for the predictors, grouped so related variables sit together:
# the t2m* temperature variables, then the cdd* and hdd* variables
# (presumably cooling / heating degree days — confirm), then the rest.
# Note: `q` is present in pred_ds but is not listed here.
vars_sorted = ["t2m", "t2m3", "t2m4", "t2min", "t2max",
               "cdd", "cdd3", "cdd4",
               "hdd", "hdd3", "hdd4",
               "msdwswrf", "mtpr", "rh", "w10"]

In [19]:
def facet(plot_fn_name, y_da, x_ds, region, filename=None):
    """
    Grid of 15 subplots (5 rows x 3 columns), one per entry of `vars_sorted`.

    For 'scatter', every panel shows the predictand against one predictor.
    For 'ts', the first panel shows the predictand (divided by 1000) in place
    of the first entry of `vars_sorted`; the remaining panels show the
    time series of the corresponding predictors.

    plot_fn_name: str, 'scatter' or 'ts' to indicate which plotting function to use
    y_da: array of predictand
    x_ds: dataset of predictors
    region: region name
    filename: filename for savefig; the figure is written as a PDF to
        "./figures/" + filename. If None, nothing is saved.

    Raises ValueError if plot_fn_name is neither 'ts' nor 'scatter'.
    """
    # Validate up front so a bad name cannot leave a half-built figure open.
    if plot_fn_name not in ("ts", "scatter"):
        raise ValueError("Incorrect plot_fn_name")

    fig, ax = plt.subplots(5, 3, figsize=(7, 9), dpi=100)
    panels = ax.flatten()

    try:
        for i, v in enumerate(vars_sorted):
            if plot_fn_name == "scatter":
                scatter(panels[i], y_da, x_ds, v, region)
            elif i == 0:
                # First time-series panel is reserved for the predictand.
                ts(panels[0], y_da / 1000, region)
            else:
                ts(panels[i], x_ds[v], region)

        fig.tight_layout()

        if filename is not None:
            out_path = "./figures/" + filename
            # Create the output directory on first use instead of failing.
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            fig.savefig(out_path, format="pdf", dpi=400, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving raised;
        # many figures are produced in a loop.
        plt.close(fig)

In [20]:
# Produce one time-series figure and one scatter figure for every region in
# pred_ds. Each figure is saved by `facet` as "<plot type>_<region>.pdf".
for plot_name in ["ts", "scatter"]:
    for region in pred_ds["region"].values:
        fp = plot_name + "_" + region + ".pdf"
        facet(plot_name, dem_da, pred_ds, region, fp)