# python setup

In [1]:
import numpy as np
import pandas as pd
import xarray as xr
from glob import glob
import matplotlib.pyplot as plt
import flox
from pathlib import Path

# define filepaths
# Relative output folder prefix; every plot folder used below is built by appending to it.
f_plot_ptrn = "plots/bias_crctd_mrms_qaqc/"
# NOTE(review): absolute, machine-specific path; the notebook will not run elsewhere without editing it.
fldr_mrms_constant_tstep = "D:/Dropbox/_GradSchool/_norfolk/highres-radar-rainfall-processing/data/mrms_nc_preciprate_fullres_dailyfiles_constant_tstep/"
# All netCDF files in that folder whose name contains "qaqc".
lst_fs_qaqc = glob(fldr_mrms_constant_tstep + "*qaqc*.nc")
# Single QA/QC file opened later with a `date` dimension (presumably daily data at Stage IV resolution, going by the name).
f_nc_qc_daily_st4res = fldr_mrms_constant_tstep + "_qaqc_of_resampled_data_st4_res.nc"

# load datasets
# ds_qc_daily_st4res = xr.open_dataset(f_nc_qc_daily_st4res, chunks = dict(latitude=-1,longitude=-1,date = -1))

In [8]:
# data preprocessing
## add year_month coordinate to datasets
# ds_qc_daily_st4res.coords["year_month"] = ds_qc_daily_st4res['date'].dt.strftime('%Y-%m')

# General Functions

In [2]:
def return_datasets_of_interest(lst_fs_qaqc, res, var, lst_da_vars_to_subset = None):
    """Open the yearly, year-month and monthly QA/QC aggregate files for one resolution and statistic.

    Files are selected purely by substring matching on each path: a path
    containing "st4" is treated as a Stage IV-resolution file (anything else is
    treated as MRMS resolution), `var` must appear in the path, and the
    substrings "date.year", "year_month" and "date.month" identify the
    aggregation period.

    Parameters
    ----------
    lst_fs_qaqc : list of str
        Candidate netCDF file paths.
    res : str
        Either "st4" or "mrms".
    var : str
        Aggregation statistic; one of "quants", "max", "mean", "min", "sum".
    lst_da_vars_to_subset : list of str, optional
        Data variable names to keep in each returned dataset. When None the
        full datasets are returned.

    Returns
    -------
    tuple or None
        (ds_yearly, ds_year_mnth, ds_month). None is returned (after printing
        a message) when an argument fails validation.

    Raises
    ------
    FileNotFoundError
        If no matching file is found for one or more aggregation periods.
    """
    # make sure arguments are valid
    lst_res = ["st4", "mrms"]
    lst_vars = ["quants", "max", "mean", "min", "sum"]
    lst_da_varnames=  ['frac_of_tot_biascrctd_rain_from_stageiv_fill', 'hours_of_stageiv_fillvalues', 'max_daily_correction_factor',
                'mean_daily_correction_factor', 'mrms_biascorrected_daily_totals_mm', 'mrms_biascorrected_minus_stageiv_mm',
                    'mrms_nonbiascorrected_daily_totals_mm', 'mrms_nonbiascorrected_minus_stageiv_mm', 'q0.1_correction_factor',
                    'q0.5_correction_factor', 'q0.9_correction_factor', 'stageiv_daily_totals_mm', 'total_stageiv_fillvalues_mm']
    if res not in lst_res:
        print("res argument must be on of {}".format(lst_res))
        return
    if var not in lst_vars:
        print("var argument must be on of {}".format(lst_vars))
        return
    if lst_da_vars_to_subset is not None:
        # collect every invalid name (the list must be built once, outside any loop,
        # otherwise only the last name would be checked)
        lst_problem_varnames = [varname for varname in lst_da_vars_to_subset if varname not in lst_da_varnames]
        if len(lst_problem_varnames) > 0:
            print("List of da variable names is not valid. Problem varnames: ")
            print(lst_problem_varnames)
            return

    # create a list of only the files with the desired resolution
    wants_st4 = (res == "st4")
    lst_fs_qaqc_res = [f for f in lst_fs_qaqc if ("st4" in f) == wants_st4]

    # create xarray datasets of the desired variable and resolution, keyed by the
    # filename substring that identifies the aggregation period
    dic_ds_by_period = {"date.month": None, "date.year": None, "year_month": None}
    lst_files_opened = []
    for f in lst_fs_qaqc_res:
        if var not in f:
            continue
        for period_pattern in dic_ds_by_period:
            if period_pattern in f:
                # a later match for the same period replaces an earlier one
                dic_ds_by_period[period_pattern] = xr.open_dataset(f)
                lst_files_opened.append(f)

    if len(lst_files_opened) != 3:
        print("WARNING: Expected 3 files to be opened but {} were opened, indicating an issue in recognizing the file naming pattern. Files opened:".format(len(lst_files_opened)))
        for f in lst_files_opened:
            print(f)

    lst_missing_periods = [period_pattern for period_pattern, ds in dic_ds_by_period.items() if ds is None]
    if len(lst_missing_periods) > 0:
        raise FileNotFoundError(
            "No '{}' file at resolution '{}' found for aggregation pattern(s): {}".format(var, res, lst_missing_periods))

    ds_month = dic_ds_by_period["date.month"]
    ds_yearly = dic_ds_by_period["date.year"]
    ds_year_mnth = dic_ds_by_period["year_month"]
    if lst_da_vars_to_subset is not None:
        return ds_yearly[lst_da_vars_to_subset], ds_year_mnth[lst_da_vars_to_subset], ds_month[lst_da_vars_to_subset]
    return ds_yearly, ds_year_mnth, ds_month

def plot_multipanel_plot(ds, parent_fldr, lst_vars, lst_cbar_lab = None, lst_vmin = None, lst_vmax = None, lst_cmap = None, figpath_suffix="",
                         fig_title_suffix = ""):
    """Save one multi-panel pcolormesh figure per value of the dataset's aggregation coordinate.

    The aggregation coordinate is the last coordinate of `ds` that is not
    "longitude", "latitude" or "spatial_ref"; dimension coordinates are
    preferred so that scalar coordinates left behind by an earlier `.sel`
    are not picked up. Each figure holds one panel per entry of `lst_vars`,
    laid out on a near-square grid.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing every variable named in `lst_vars`.
    parent_fldr : str
        Output folder prefix (created if missing). File names are appended
        directly to it, so it should end with a path separator.
    lst_vars : list of str
        Data variable names to plot, one per panel.
    lst_cbar_lab, lst_vmin, lst_vmax, lst_cmap : list, optional
        Per-panel colorbar label, color limits and colormap, indexed like
        `lst_vars`. Any list left as None is not passed on to the plot call.
    figpath_suffix : str
        Suffix placed before ".png" in each file name.
    fig_title_suffix : str
        Text appended to each figure title.

    Raises
    ------
    ValueError
        If `lst_vars` is empty or `ds` has no non-spatial coordinate.
    """
    n_vars = len(lst_vars)
    if n_vars == 0:
        raise ValueError("lst_vars must contain at least one variable to plot")
    Path(parent_fldr).mkdir(parents=True, exist_ok=True)

    # return agg var
    lst_non_spatial = [coord for coord in ds.coords if coord not in ["longitude", "latitude", "spatial_ref"]]
    if len(lst_non_spatial) == 0:
        raise ValueError("ds has no aggregation coordinate (only spatial coordinates were found)")
    lst_dim_coords = [coord for coord in lst_non_spatial if coord in ds.dims]
    # fall back to any non-spatial coordinate when none of them is a dimension
    agg = (lst_dim_coords or lst_non_spatial)[-1]

    ncols = int(np.ceil(np.sqrt(n_vars)))
    nrows = int(np.ceil(n_vars/ncols))
    agg_vars = pd.unique(ds[agg].values)
    for agg_var in agg_vars:
        ds_sub = ds.sel({agg:agg_var})
        if agg == "date":
            # keep only the YYYY-MM-DD part of the timestamp for file names and titles
            agg_var = str(agg_var)[0:10]
        fig_fpath = "{}{}_{}_{}.png".format(parent_fldr,agg,agg_var,figpath_suffix)

        # squeeze=False always yields a 2D axes array, so single-row and single-panel layouts also work
        fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize = (ncols*4.75, nrows*4), dpi = 250, squeeze = False)
        for count, ax in enumerate(axes.ravel()):
            if count >= n_vars:
                # hide grid cells that have no variable assigned
                ax.set_visible(False)
                continue
            var_to_plot = lst_vars[count]
            dict_cbar_args = {}
            pcolormesh_args = {}
            if lst_cbar_lab is not None:
                dict_cbar_args["label"] = lst_cbar_lab[count]
            if lst_vmin is not None:
                pcolormesh_args["vmin"] = lst_vmin[count]
            if lst_vmax is not None:
                pcolormesh_args["vmax"] = lst_vmax[count]
            if lst_cmap is not None:
                pcolormesh_args["cmap"] = lst_cmap[count]
            ds_sub[var_to_plot].plot.pcolormesh(ax = ax, **pcolormesh_args, cbar_kwargs=dict_cbar_args)
            ax.set_title(var_to_plot)
        fig.suptitle("{} {} | {}".format(agg, agg_var, fig_title_suffix))
        fig.tight_layout()
        fig.savefig(fig_fpath)
        # close this specific figure so figures do not accumulate across the loop
        plt.close(fig)

# define functions
def determine_vmax_for_multiple_vars(ds, lst_vars, quantile_cutoff = 0.98):
    """Return the largest `quantile_cutoff` quantile found across several variables.

    Each variable in `lst_vars` is converted with `to_dataframe()`, the
    quantile is taken per column, and the largest column value is compared
    against the running maximum. The running maximum starts at -9999, which
    is also what is returned when `lst_vars` is empty. The result is printed
    before being returned.
    """
    largest = -9999
    for name in lst_vars:
        candidate = ds[name].to_dataframe().quantile(quantile_cutoff).max()
        largest = candidate if candidate > largest else largest
    print("vmax = {}".format(largest))
    return largest

# QAQC

## Stage IV, MRMS Differences and Stage IV Fill Values

In [10]:
# load data
# Variables kept from the aggregate datasets; this same list is later used as the panel list (lst_vars).
lst_da_vars_to_subset = ["mrms_biascorrected_minus_stageiv_mm", "mrms_biascorrected_daily_totals_mm", "mrms_nonbiascorrected_daily_totals_mm",
                         "mrms_nonbiascorrected_minus_stageiv_mm", "stageiv_daily_totals_mm", "total_stageiv_fillvalues_mm"]
var = "sum"
res = "st4"
# The third return value (the monthly dataset) is not used in this section.
# NOTE(review): `___` is also an IPython output-history name; a plain throwaway name would be safer.
ds_yearly, ds_year_mnth, ___ = return_datasets_of_interest(lst_fs_qaqc, res=res, var=var,
                                                                lst_da_vars_to_subset = lst_da_vars_to_subset)

# Variable groups read as globals by return_list_of_plotting_params to derive shared color limits per group.
lst_da_minuses = ["mrms_biascorrected_minus_stageiv_mm","mrms_nonbiascorrected_minus_stageiv_mm"]
lst_da_raindepths = ["mrms_biascorrected_daily_totals_mm", "mrms_nonbiascorrected_daily_totals_mm", "stageiv_daily_totals_mm"]
lst_fillvals = ["total_stageiv_fillvalues_mm"]

# ds = ds_yearly
# cbar_quant_cutoff = 0.98

def return_list_of_plotting_params(ds, cbar_quant_cutoff  = 0.98, lst_vars_to_plot = None, lst_minus_vars = None,
                                   lst_raindepth_vars = None, lst_fillval_vars = None):
    """Build per-variable vmin, vmax, colormap and colorbar-label lists for the multi-panel plots.

    Variables are grouped by name and every variable in a group gets the same limits:

    - name contains "minus": symmetric limits (-lim, lim), where lim is the larger of the
      upper-quantile maximum and the absolute lower-quantile maximum over `lst_minus_vars`;
      colormap "seismic_r".
    - name contains "fillvalues": vmin is the minimum over `lst_fillval_vars`, vmax is their
      upper-quantile maximum; colormap "binary".
    - anything else: vmin is the lower-quantile minimum and vmax the upper-quantile maximum
      over `lst_raindepth_vars`; colormap "GnBu".

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset holding the variables of every group that is actually needed.
    cbar_quant_cutoff : float
        Upper quantile; the lower quantile is 1 - cbar_quant_cutoff.
    lst_vars_to_plot, lst_minus_vars, lst_raindepth_vars, lst_fillval_vars : list of str, optional
        Variables to produce parameters for and the members of each group. When None they fall
        back to the notebook-level lists lst_da_vars_to_subset, lst_da_minuses, lst_da_raindepths
        and lst_fillvals, so existing calls behave as before. Pass them explicitly to avoid
        depending on whichever cell last rebound those globals.

    Returns
    -------
    tuple of lists
        (lst_vmin, lst_vmax, lst_cmap, lst_cbar_lab), each indexed like the variable list.
        Every colorbar label is an empty string.
    """
    if lst_vars_to_plot is None:
        lst_vars_to_plot = lst_da_vars_to_subset
    if lst_minus_vars is None:
        lst_minus_vars = lst_da_minuses
    if lst_raindepth_vars is None:
        lst_raindepth_vars = lst_da_raindepths
    if lst_fillval_vars is None:
        lst_fillval_vars = lst_fillvals

    ds_quants = ds.quantile(q = cbar_quant_cutoff)
    ds_quants_low = ds.quantile(q = (1-cbar_quant_cutoff))
    ds_min = ds.min()

    def _to_series(ds_stat, lst_names):
        # reduce a statistics dataset to a pandas Series holding only the named variables
        return ds_stat[lst_names].load().to_pandas()

    # limits are identical for every variable of a group, so compute each group at most once
    # (and only if a variable of that group is actually requested)
    dic_group_params = {}

    def _group_params(group):
        if group not in dic_group_params:
            if group == "minus":
                lim = _to_series(ds_quants, lst_minus_vars).max()
                lim2 = _to_series(ds_quants_low, lst_minus_vars).abs().max()
                lim = max(lim, lim2)
                dic_group_params[group] = (-lim, lim, "seismic_r")
            elif group == "raindepth":
                lim = _to_series(ds_quants, lst_raindepth_vars).max()
                low_lim = _to_series(ds_quants_low, lst_raindepth_vars).min()
                dic_group_params[group] = (low_lim, lim, "GnBu")
            else:
                lim = _to_series(ds_quants, lst_fillval_vars).max()
                low_lim = _to_series(ds_min, lst_fillval_vars).min()
                dic_group_params[group] = (low_lim, lim, "binary")
        return dic_group_params[group]

    lst_vmin = []
    lst_vmax = []
    lst_cmap = []
    lst_cbar_lab = []
    for da_varname in lst_vars_to_plot:
        if "minus" in da_varname:
            group = "minus"
        elif "fillvalues" not in da_varname:
            group = "raindepth"
        else:
            group = "fillvalues"
        vmin, vmax, cmap = _group_params(group)
        lst_vmin.append(vmin)
        lst_vmax.append(vmax)
        lst_cmap.append(cmap)
        lst_cbar_lab.append("")
    return lst_vmin, lst_vmax, lst_cmap, lst_cbar_lab

# Title suffix and panel list shared by the yearly, year-month and daily plotting cells below.
fig_title_suffix = "aggregation statistic: {} | resolution: {}".format(var, res)
lst_vars = lst_da_vars_to_subset

In [11]:
# ds = ds_subset
# cbar_quant_cutoff = 0.98

# ds_quants = ds.quantile(q = cbar_quant_cutoff)
# ds_quants_low = ds.quantile(q = (1-cbar_quant_cutoff))
# ds_min = ds.min()
# lst_vmin = []
# lst_vmax = []
# lst_cmap = []
# lst_cbar_lab = []
# for da_varname in lst_da_vars_to_subset:
#     if "minus" in da_varname:
#         lim = ds_quants[lst_da_minuses].load().to_pandas().max()
#         lim2 = ds_quants_low[lst_da_minuses].load().to_pandas().abs().max()
#         lim = max(lim, lim2)
#         lst_vmin.append(-lim)
#         lst_vmax.append(lim)
#         lst_cmap.append("seismic_r")
#     elif "fillvalues" not in da_varname:
#         lim = ds_quants[lst_da_raindepths].load().to_pandas().max()
#         low_lim = ds_quants_low[lst_da_raindepths].load().to_pandas().min()
#         lst_vmin.append(low_lim)
#         lst_vmax.append(lim)
#         lst_cmap.append("GnBu")
#     else:
#         lim = ds_quants[lst_fillvals].load().to_pandas().max()
#         low_lim = ds_min[lst_fillvals].load().to_pandas().min()
#         lst_vmin.append(max(low_lim,0))
#         lst_vmax.append(lim)
#         lst_cmap.append("binary")
#     lst_cbar_lab.append("")

In [12]:
# ds_quants[lst_da_minuses].load().to_pandas()

### Yearly

In [13]:
# One multi-panel figure per value of the yearly dataset's aggregation coordinate.
parent_fldr = f_plot_ptrn + "annual_sum_diffs_fillvals/"

# Color limits are derived from ds_yearly itself, so they are shared across all of its figures.
lst_vmin, lst_vmax, lst_cmap, lst_cbar_lab = return_list_of_plotting_params(ds_yearly, cbar_quant_cutoff  = 0.98)
plot_multipanel_plot(ds_yearly, parent_fldr, lst_vars, lst_cbar_lab = lst_cbar_lab, lst_vmin = lst_vmin, lst_vmax = lst_vmax, lst_cmap = lst_cmap, figpath_suffix="sum_diffs",
                     fig_title_suffix = fig_title_suffix)

### Year-Month

In [14]:
# One multi-panel figure per value of the year-month dataset's aggregation coordinate.
parent_fldr = f_plot_ptrn + "year_month_sum_diffs_fillvals/"

# Color limits are recomputed from ds_year_mnth, overwriting the lists produced for the yearly plots.
lst_vmin, lst_vmax, lst_cmap, lst_cbar_lab = return_list_of_plotting_params(ds_year_mnth, cbar_quant_cutoff  = 0.98)
plot_multipanel_plot(ds_year_mnth, parent_fldr, lst_vars, lst_cbar_lab = lst_cbar_lab, lst_vmin = lst_vmin, lst_vmax = lst_vmax, lst_cmap = lst_cmap, figpath_suffix="sum_diffs",
                     fig_title_suffix = fig_title_suffix)

### Looking into January and February 2015

In [10]:
# Requires lst_da_vars_to_subset from the "load data" cell of this section; running this cell
# before that one raises NameError.
# chunks=-1 on every listed dimension opens the file lazily as a single chunk per variable.
ds_qc_daily_st4res = xr.open_dataset(f_nc_qc_daily_st4res, chunks = dict(latitude=-1,longitude=-1,date = -1))
ds_subset = ds_qc_daily_st4res[lst_da_vars_to_subset]
# Partial date-string selection, then load the result into memory.
# NOTE(review): "02-2015" is presumably February 2015; an ISO string ("2015-02") would be unambiguous.
ds_subset = ds_subset.sel(date = "02-2015").load()
# ds_subset.total_stageiv_fillvalues_mm.to_dataframe().total_stageiv_fillvalues_mm.unique().shape
ds_subset

NameError: name 'lst_da_vars_to_subset' is not defined

In [16]:
# Open the daily QA/QC file without chunking just to display its structure as the cell output.
xr.open_dataset(f_nc_qc_daily_st4res)

In [17]:
# load data
ds_qc_daily_st4res = xr.open_dataset(f_nc_qc_daily_st4res, chunks = dict(latitude=-1,longitude=-1,date = -1))
ds_subset = ds_qc_daily_st4res[lst_da_vars_to_subset]
ds_subset = ds_subset.sel(date = "01-03-2015").load()
# agg = "date"
# agg_vars = pd.unique(ds[agg].values)
# str(agg_vars[0])[0:10]
# ds_subset
ds_subset

In [18]:
# Distinct Stage IV fill-value totals present in the selected subset (displayed as the cell output).
ds_subset.total_stageiv_fillvalues_mm.to_dataframe().total_stageiv_fillvalues_mm.unique()

array([5.73116206e-02, 1.01556487e-01, 4.36600065e-03, ...,
       1.19142405e-05, 2.38492532e-04, 4.91924554e-01], dtype=float32)

In [19]:
# Output folder and color limits for the daily subset; the plotting call itself is commented out in the next cell.
parent_fldr = f_plot_ptrn + "jan2015_sum_diffs_fillvals/"

lst_vmin, lst_vmax, lst_cmap, lst_cbar_lab = return_list_of_plotting_params(ds_subset, cbar_quant_cutoff  = 0.98)


In [23]:
# plot_multipanel_plot(ds_subset, parent_fldr, lst_vars, lst_cbar_lab = lst_cbar_lab, lst_vmin = lst_vmin, lst_vmax = lst_vmax, lst_cmap = lst_cmap, figpath_suffix="sum_diffs",
#                      fig_title_suffix = fig_title_suffix)

## Hexplots of Stage IV vs. MRMS Daily Totals

In [3]:
# load data
ds_qc_daily_st4res = xr.open_dataset(f_nc_qc_daily_st4res, chunks = dict(latitude=-1,longitude=-1,date = -1))
ds_subset = ds_qc_daily_st4res[["mrms_biascorrected_daily_totals_mm", "mrms_nonbiascorrected_daily_totals_mm","stageiv_daily_totals_mm"]]

In [6]:
# define functions for computing whether points are within 10% of 1 to 1 line
def comp_dist_to_1to1_line(df, xvar, yvar):
    """Perpendicular distance from each (xvar, yvar) point to the line y = x, plus the closest point on it.

    The line is written as coef_x*x + coef_y*y + intercept = 0 with
    coef_x = -1, coef_y = 1 and intercept = 0.

    Returns three Series aligned with `df`, named "distance_to_1to1_line",
    "x_coord_of_nearest_pt_on_1to1_line" and "y_coord_of_nearest_pt_on_1to1_line".
    """
    # https://en.wikipedia.org/wiki/Distance_from_a_point_to_a_line
    coef_y = 1
    coef_x = -1*coef_y
    intercept = 0
    norm_sq = coef_x**2 + coef_y**2
    px, py = df[xvar], df[yvar]

    perp_dist = np.abs(coef_x * px + coef_y * py + intercept) / np.sqrt(norm_sq)
    foot_x = (coef_y * (coef_y*px - coef_x * py) - coef_x * intercept) / norm_sq
    foot_y = (coef_x*(-coef_y*px + coef_x*py) - coef_y*intercept)/norm_sq

    return (perp_dist.rename("distance_to_1to1_line"),
            foot_x.rename("x_coord_of_nearest_pt_on_1to1_line"),
            foot_y.rename("y_coord_of_nearest_pt_on_1to1_line"))

def compute_frac_within_tolerance_of_1to1_line(dist, nearest_x, frac_tol):
    """Fraction of points whose distance to the 1:1 line is within tolerance.

    A point counts as within tolerance when `dist` is at most
    `nearest_x * frac_tol / 2`. The denominator is the total number of
    entries, so entries where the comparison is False (including NaN)
    count as outside.
    """
    allowed_dist = nearest_x*(frac_tol/2)
    is_close = dist <= allowed_dist
    n_total = len(is_close)
    return is_close.sum()/n_total

# define hexbin plot function
def plt_hexbin(df, frac_tol = 0.1, col1="mrms_biascorrected_daily_totals_mm", col2="mrms_nonbiascorrected_daily_totals_mm",
                refcol ="stageiv_daily_totals_mm",logcount = True, gridcnt = 40, fig_fpath = None,
                fig_title = None):
    """Draw two side-by-side hexbin plots comparing `col1` and `col2` (x axes) against `refcol` (y axes).

    Both panels share axis limits of 0 to the maximum of `refcol`. Each panel shows the 1:1 line,
    dashed guide lines at the tolerance distance on either side of it, and a title reporting the
    percentage of rows within tolerance (see compute_frac_within_tolerance_of_1to1_line).

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns `col1`, `col2` and `refcol`.
    frac_tol : float
        Tolerance as a fraction; the allowed distance from the 1:1 line at a point is
        frac_tol/2 times the x coordinate of the nearest point on the line.
    logcount : bool
        If True the hexagon counts use a logarithmic color scale.
    gridcnt : int
        Number of hexagons in the x direction; the y count is gridcnt/sqrt(3), rounded.
    fig_fpath : str, optional
        If given, the figure is saved there; the parent folder is created when missing.
    fig_title : str, optional
        Figure-level title.

    The figure is always closed before returning; nothing is returned.
    """
    fig, (ax0, ax1) = plt.subplots(ncols=2, figsize=(10, 4), dpi = 300)

    ylim = xlim = (0, df.loc[:,refcol ].max())
    nx = gridcnt
    ny = int(round(nx / np.sqrt(3),0))
    extent = xlim[0], xlim[1], ylim[0], ylim[1]

    # hexagons holding fewer than 5 points are left blank
    hexbin_args = dict(cmap='inferno', mincnt=5, gridsize=(nx, ny), extent = extent)
    if logcount:
        hexbin_args["bins"] = 'log'

    # points along the 1:1 line from the origin to the upper axis limit; xlim is a
    # (low, high) tuple, so only its upper bound is the end of the range
    x0 = y0 = np.linspace(0, xlim[1], 500)
    dist = x0 * frac_tol/2
    # based on a^2 + b^2 = c^2, assuming 1:1 line so a=b
    offset = np.sqrt(dist**2/2)
    x_upper, y_upper = x0 - offset, y0 + offset
    x_lower, y_lower = x0 + offset, y0 - offset

    for ax, col in ((ax0, col1), (ax1, col2)):
        ax.set(xlim=xlim, ylim=ylim)
        hb = ax.hexbin(df.loc[:, col], df.loc[:,refcol ], **hexbin_args)

        # compute frac of points within tolerance
        dist_to_line, nearest_x, _ = comp_dist_to_1to1_line(df, col, refcol)
        frac_pts_in_tol = compute_frac_within_tolerance_of_1to1_line(dist_to_line, nearest_x, frac_tol)

        ax.set_xlabel(df.loc[:, col].name)
        ax.set_ylabel(df.loc[:,refcol ].name)
        ax.set_title("Percent of observations within {}% of the 1:1 line: {}%".format(int(frac_tol*100), round(frac_pts_in_tol*100, 2)),
                     fontsize = 9)

        # add lines showing tolerance threshold
        ax.plot(x_upper, y_upper, label = "upper bound", c = "cyan", ls = "--", linewidth = 0.7, alpha = 0.8)
        ax.plot(x_lower, y_lower, label = "lower bound", c = "cyan", ls = "--", linewidth = 0.7, alpha = 0.8)

        fig.colorbar(hb, ax=ax, label='')
        ax.axline((0, 0), slope=1, c = "cyan", ls = "--", linewidth = 1.2)

    if fig_title is not None:
        fig.suptitle(fig_title)
    # fig.tight_layout()
    if fig_fpath is not None:
        # the output folder may not exist yet
        Path(fig_fpath).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(fig_fpath)
    plt.close(fig)

    return

### Yearly Plots

In [8]:
# One hexbin comparison figure per calendar year present in the daily data.
years = pd.unique(ds_subset.date.dt.year.values)
#%% work
# years = years[0:1]
#%% end work
for yr in years:
    # keep only the dates that fall in this year, then flatten to a table for plotting
    ds_sub_yr = ds_subset.where(ds_subset.date.dt.year == yr, drop = True)
    df_subset = ds_sub_yr.to_dataframe()
    plt_hexbin(df_subset, fig_fpath = "{}annual_hexbin_daily_comp/year_{}_hexbin_daily_tots.png".format(f_plot_ptrn,yr), fig_title = "year {} daily total comparison".format(yr))

### Year-Month Plots

In [9]:
# Tag every date with a "YYYY-MM" label, then produce one hexbin comparison figure per label.
# Note: this adds a coordinate to ds_subset in place.
ds_subset.coords["year_month"] = ds_subset['date'].dt.strftime('%Y-%m')
yearmonths = pd.unique(ds_subset.year_month.values)
#%% work
# yearmonths = yearmonths[0:1]
#%% end work
for yrmnth in yearmonths:
    # keep only the dates carrying this year-month label, then flatten to a table for plotting
    ds_sub_yrmnth = ds_subset.where(ds_subset.year_month == yrmnth, drop = True)
    df_subset = ds_sub_yrmnth.to_dataframe()
    plt_hexbin(df_subset, fig_fpath = "{}yearmonth_hexbin_daily_comp/yearmonth_{}_hexbin_daily_tots.png".format(f_plot_ptrn,yrmnth), fig_title = "year-month {} daily total comparison".format(yrmnth))

## Bias Correction Factors

In [28]:
# Load the correction-factor variables from the MRMS-resolution "quants" aggregate files.
# NOTE(review): this rebinds the globals lst_da_vars_to_subset, var, res, ds_yearly and ds_year_mnth
# used by the earlier "sum"/"st4" section; return_list_of_plotting_params reads lst_da_vars_to_subset
# as a global, so re-running earlier cells after this one gives different results.
lst_da_vars_to_subset = ["mean_daily_correction_factor", "q0.1_correction_factor","q0.5_correction_factor", "q0.9_correction_factor"]
var = "quants"
res = "mrms"
ds_yearly, ds_year_mnth, ds_month = return_datasets_of_interest(lst_fs_qaqc, res=res, var=var,
                                                                lst_da_vars_to_subset = lst_da_vars_to_subset)

In [29]:
def plot_bias_crxn_facs(var, dic_subset = None):
    """Plot the daily bias-correction-factor variables for the yearly, monthly and year-month aggregates.

    Loads the MRMS-resolution aggregate files for statistic `var`, optionally
    applies `dic_subset` to each dataset via `.sel`, and writes multi-panel
    figures into three folders under `f_plot_ptrn`. Every panel uses color
    limits 0 to 2, the "seismic_r" colormap and an empty colorbar label.
    """
    # load data
    crxn_varnames = ["mean_daily_correction_factor", "q0.1_correction_factor","q0.5_correction_factor", "q0.9_correction_factor"]
    res = "mrms"
    ds_yearly, ds_year_mnth, ds_month = return_datasets_of_interest(lst_fs_qaqc, res=res, var=var,
                                                                    lst_da_vars_to_subset = crxn_varnames)

    if dic_subset is not None:
        ds_yearly, ds_year_mnth, ds_month = (
            ds_agg.sel(dic_subset) for ds_agg in (ds_yearly, ds_year_mnth, ds_month))

    # identical styling for every panel
    n_panels = len(crxn_varnames)
    shared_plot_kwargs = dict(
        lst_cbar_lab = [""] * n_panels,
        lst_vmin = [0] * n_panels,
        lst_vmax = [2] * n_panels,
        lst_cmap = ["seismic_r"] * n_panels,
        figpath_suffix = "mean_crxn",
        fig_title_suffix = "aggregation statistic: {} | resolution: {}".format(var, res),
    )

    # (dataset, output folder template) in the order the figures are written
    plot_jobs = (
        (ds_yearly, "annual_{}_daily_crxn_factor/"),
        (ds_month, "monthly_{}_daily_crxn_factor/"),
        (ds_year_mnth, "year_{}_mean_daily_crxn_factor/"),
    )
    for ds_agg, fldr_template in plot_jobs:
        out_fldr = f_plot_ptrn + fldr_template.format(var)
        Path(out_fldr).mkdir(parents=True, exist_ok=True)
        plot_multipanel_plot(ds_agg, out_fldr, crxn_varnames, **shared_plot_kwargs)

In [30]:
# Correction-factor plots from the "mean" aggregate files, with no subsetting.
plot_bias_crxn_facs(var="mean")

In [31]:
# Correction-factor plots from the "quants" aggregate files, selecting quantile 0.5 from each dataset.
dic_subset = {"quantile":0.5}
plot_bias_crxn_facs(var="quants", dic_subset=dic_subset)

# Archived

## Rainfall Totals

In [None]:
# load data
ds_yearly, ds_year_mnth, __ = return_datasets_of_interest(lst_fs_qaqc, res="st4", var="sum",
                                                                lst_da_vars_to_subset = ["mrms_biascorrected_daily_totals_mm",
                                                                                         'stageiv_daily_totals_mm',
                                                                                         'mrms_nonbiascorrected_daily_totals_mm'])

In [None]:
# define functions
def determine_vmax_for_raintots(ds_sum, quantile_cutoff = 0.98, var_to_use = None):
    """Pick a colorbar maximum from the `quantile_cutoff` quantile of the dataset's values.

    With `var_to_use` given, only that column's quantile is used. Otherwise
    the per-column quantiles are printed and the largest one is used. The
    chosen value is printed and returned.
    """
    frame = ds_sum.to_dataframe()
    if var_to_use is None:
        per_column = frame.quantile(quantile_cutoff)
        print(per_column)
        vmax = per_column.max()
    else:
        vmax = frame[var_to_use].quantile(quantile_cutoff)
    print("vmax = {}".format(vmax))
    return vmax

def plot_rainfall_tots(ds_sum, agg, vmax, parent_fldr):
    """Save a three-panel rainfall-total figure for each value of the aggregation coordinate.

    Panels, left to right: non-bias-corrected MRMS, Stage IV, bias-corrected MRMS.
    All panels share color limits 0 to `vmax` and the "seismic_r" colormap.
    In the Stage IV panel, values that are not >= 0 (negative or NaN) are shown as 0.

    Parameters
    ----------
    ds_sum : xarray.Dataset
        Dataset containing the three daily-total variables and the `agg` coordinate.
    agg : str
        Name of the aggregation coordinate; used to list the values to plot and in
        file names and titles.
    vmax : float
        Upper color limit for every panel.
    parent_fldr : str
        Output folder prefix (created if missing). File names are appended directly
        to it, so it should end with a path separator.
    """
    # make sure the output folder exists before saving, as plot_multipanel_plot does
    Path(parent_fldr).mkdir(parents=True, exist_ok=True)
    figsize = (13,4)

    varname_left = "mrms_nonbiascorrected_daily_totals_mm"
    varname_middle = "stageiv_daily_totals_mm"
    varname_right = "mrms_biascorrected_daily_totals_mm"

    agg_vars = pd.unique(ds_sum[agg].values)

    # return name of new coordinate: the last non-spatial coordinate, defaulting to
    # `agg` itself when the dataset exposes no such coordinate
    new_coord = agg
    for coord in ds_sum.coords:
        if coord not in ["longitude", "latitude", "spatial_ref"]:
            new_coord = coord

    for agg_var in agg_vars:
        ds_sub = ds_sum.sel({new_coord:agg_var})
        fig_fpath = "{}{}_{}_totals.png".format(parent_fldr,agg,agg_var)

        fig, axes = plt.subplots(ncols=3, figsize = figsize, dpi = 250)
        for ax, varname in zip(axes, (varname_left, varname_middle, varname_right)):
            da = ds_sub[varname]
            if varname == varname_middle:
                # replace anything that is not >= 0 with 0 before plotting
                da = da.where(da>=0, 0)
            da.plot.pcolormesh(ax = ax, cbar_kwargs={"label": "rainfall (mm)"}, vmin = 0, vmax=vmax,cmap = "seismic_r")
            ax.set_title(varname)
        fig.suptitle("{} {}".format(agg, agg_var))
        fig.tight_layout()
        fig.savefig(fig_fpath)
        # close this specific figure so figures do not accumulate across the loop
        plt.close(fig)

### Year-Month Statistics

In [None]:
# Year-month totals: shared color maximum from the 0.98 quantile, then one three-panel figure per year-month.
vmax_yrmnth = determine_vmax_for_raintots(ds_year_mnth, quantile_cutoff = 0.98)
plot_rainfall_tots(ds_year_mnth, "year_month", vmax_yrmnth, parent_fldr = f_plot_ptrn + "year_monthly_totals/")

mrms_biascorrected_daily_totals_mm       347.477296
stageiv_daily_totals_mm                  322.241970
mrms_nonbiascorrected_daily_totals_mm    339.271205
spatial_ref                                0.000000
Name: 0.98, dtype: float64
vmax = 347.47729553222644


### Yearly Statistics

In [None]:
# Yearly totals: shared color maximum from the 0.98 quantile, then one three-panel figure per year.
vmax_yr = determine_vmax_for_raintots(ds_yearly, quantile_cutoff = 0.98)
plot_rainfall_tots(ds_yearly, "year", vmax_yr,parent_fldr = f_plot_ptrn + "annual_totals/")

mrms_biascorrected_daily_totals_mm       2819.901582
stageiv_daily_totals_mm                  2669.776934
mrms_nonbiascorrected_daily_totals_mm    2626.535488
spatial_ref                                 0.000000
Name: 0.98, dtype: float64
vmax = 2819.9015820312497
