In [1]:
import pandas as pd
import numpy as np
import pickle
import json
import matplotlib.pyplot as plt
from matplotlib.patches import Patch as mpatch
from matplotlib.lines import Line2D as mline
import seaborn as sns

In [2]:
# Plot styling. The bare "seaborn" style was renamed to "seaborn-v0_8" in
# matplotlib 3.6 (deprecated) and removed in 3.8, so prefer the new name when
# this matplotlib provides it and fall back to the old one otherwise.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")
sns.set_context("talk")
# Colour cycle of the active style, kept for consistent manual colouring later.
style_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

In [3]:
# Load every basin's fitted tree-model results. Each basin is expected to have
# one, so a missing file raises instead of being skipped.
tree_results = dict(
    (basin_name, pd.read_pickle(f"../../results/basin_eval/{basin_name}/treed_model/results.pickle"))
    for basin_name in ("upper_col", "lower_col", "tva", "missouri", "pnw")
)

In [5]:
# Load the simple-model results. Not every basin has one, so basins whose
# results file is missing are deliberately skipped.
simp_results = {}
for basin_name in ["upper_col", "lower_col", "tva", "missouri", "pnw"]:
    results_path = f"../../results/basin_eval/{basin_name}/simple_model/results.pickle"
    try:
        loaded = pd.read_pickle(results_path)
    except FileNotFoundError:
        continue
    simp_results[basin_name] = loaded


In [12]:
def get_tree_coefs(tree_results):
    """Summarise each basin's tree model as per-leaf coefficients and release ranges.

    Parameters
    ----------
    tree_results : dict
        Basin name -> results mapping with a ``"coefs"`` DataFrame (one column
        per tree leaf/group) and a ``"data"`` mapping holding ``"X_train"``
        (a DataFrame with a ``release_pre`` column) and ``"groups"`` (the
        group label of each training row).

    Returns
    -------
    dict
        Basin name -> ``{"coefs": DataFrame, "tree": DataFrame}``. ``coefs``
        has its columns renumbered 1..k in their stored order. ``tree`` has one
        row per group (renumbered 1..k in sorted-label order, index named
        ``"group"``) with the ``min_r`` / ``max_r`` range of ``release_pre``
        observed in that group.

    Notes
    -----
    The caller's ``X_train`` frames are left unmodified. Both renumberings
    are positional, so this assumes the coefficient columns are stored in the
    same order as the sorted group labels. TODO: confirm against the
    model-fitting code.
    """
    output = {}
    for basin, results in tree_results.items():
        coefs = results["coefs"]
        coefs = coefs.rename(columns={col: pos + 1 for pos, col in enumerate(coefs.columns)})
        # Attach the group labels to a copy. Assigning into X_train directly
        # would permanently add a "group" column to the caller's training data
        # (and to every later use of tree_results).
        X = results["data"]["X_train"].assign(group=results["data"]["groups"])
        tree = X.groupby("group")["release_pre"].agg(min_r="min", max_r="max")
        tree = tree.rename(index={label: pos + 1 for pos, label in enumerate(tree.index)})
        output[basin] = {"coefs": coefs, "tree": tree}
    return output

In [13]:
# Per-basin tree summaries: renumbered leaf coefficients plus release ranges per leaf.
tree_coefs = get_tree_coefs(tree_results=tree_results)

In [32]:
def get_tree_breaks(tree_coefs):
    """Collect every basin's per-leaf upper release bound into one table.

    Parameters
    ----------
    tree_coefs : dict
        Basin name -> mapping whose ``"tree"`` entry is a DataFrame with a
        ``max_r`` column.

    Returns
    -------
    pandas.DataFrame
        One column per basin (in ``tree_coefs`` order) holding its ``max_r``
        values, aligned on the leaf index.
    """
    breaks_by_basin = {}
    for basin, summary in tree_coefs.items():
        breaks_by_basin[basin] = summary["tree"]["max_r"]
    return pd.DataFrame(breaks_by_basin)

In [33]:
def get_var_compare(tree_coefs, var):
    """Compare one regressor's per-leaf coefficients across basins.

    Parameters
    ----------
    tree_coefs : dict
        Basin name -> mapping whose ``"coefs"`` entry is a DataFrame indexed
        by variable name.
    var : str
        Row label of the variable to extract (e.g. ``"inflow"``).

    Returns
    -------
    pandas.DataFrame
        One column per basin holding that basin's ``var`` row, aligned on
        the leaf labels.
    """
    rows_by_basin = {}
    for basin, summary in tree_coefs.items():
        rows_by_basin[basin] = summary["coefs"].loc[var]
    return pd.DataFrame(rows_by_basin)

In [35]:
# Upper release_pre bound of each tree leaf, one row per basin.
get_tree_breaks(tree_coefs).transpose().round(decimals=3)

group,1,2,3,4,5,6,7,8
upper_col,-0.676,-0.233,0.237,0.846,1.577,2.716,4.592,12.792
lower_col,-1.622,-0.989,-0.485,-0.047,0.386,0.903,1.299,3.518
tva,-0.879,-0.5,-0.027,0.423,1.034,1.8,2.783,19.166
missouri,-0.492,-0.219,0.144,0.611,1.431,2.546,4.172,33.144
pnw,-0.894,-0.349,0.052,0.596,1.283,2.117,3.393,6.608


In [31]:
# Per-leaf sto_diff coefficients, one row per basin.
get_var_compare(tree_coefs, "sto_diff").transpose().round(decimals=3)

Unnamed: 0,1,2,3,4,5,6,7,8
upper_col,-0.003,0.096,0.086,0.182,0.323,0.341,0.207,0.129
lower_col,-0.031,-0.02,-0.027,0.015,0.053,0.005,0.016,-0.004
tva,1.71,2.006,2.857,2.947,2.983,2.639,1.293,1.207
missouri,-0.024,0.132,0.475,0.41,0.321,0.283,0.969,-0.473
pnw,0.004,0.005,0.008,-0.014,-0.018,-0.019,0.003,-0.007


In [36]:
# Per-leaf release_pre coefficients, one row per basin.
get_var_compare(tree_coefs, "release_pre").transpose().round(decimals=3)

Unnamed: 0,1,2,3,4,5,6,7,8
upper_col,0.43,0.635,0.75,0.787,0.883,1.081,0.98,1.007
lower_col,0.477,0.546,0.478,0.514,0.475,0.563,0.69,0.602
tva,0.557,0.58,0.616,0.577,0.63,0.645,0.655,0.751
missouri,0.82,0.978,0.993,1.077,1.076,1.11,1.031,0.64
pnw,0.521,0.655,0.663,0.788,0.802,0.78,0.843,0.65


In [37]:
# Per-leaf inflow coefficients, one row per basin.
get_var_compare(tree_coefs, "inflow").transpose().round(decimals=3)

Unnamed: 0,1,2,3,4,5,6,7,8
upper_col,0.009,0.007,-0.003,0.07,0.126,0.119,0.242,0.153
lower_col,0.053,0.091,0.108,0.153,0.164,0.112,0.019,0.013
tva,-0.061,-0.037,-0.074,-0.034,-0.076,-0.026,-0.029,-0.062
missouri,-0.023,0.003,-0.046,-0.007,0.056,0.086,0.096,0.301
pnw,0.372,0.29,0.252,0.182,0.167,0.182,0.145,0.348


In [38]:
# Per-leaf storage_x_inflow coefficients, one row per basin.
get_var_compare(tree_coefs, "storage_x_inflow").transpose().round(decimals=3)

Unnamed: 0,1,2,3,4,5,6,7,8
upper_col,0.061,0.072,0.072,0.065,0.051,0.058,0.033,0.089
lower_col,0.069,-0.048,-0.052,-0.111,-0.113,-0.097,0.033,0.053
tva,0.066,0.111,0.1,0.064,0.105,0.03,-0.017,0.067
missouri,0.055,0.057,0.098,0.086,0.062,0.033,0.082,0.217
pnw,0.06,0.034,0.001,0.046,0.023,0.016,-0.003,0.009


In [47]:
def get_group_coefs(simp_results, group):
    """Collect one model group's coefficients from every basin that fitted it.

    Parameters
    ----------
    simp_results : dict
        Basin name -> loaded simple-model results. Each value must have a
        ``"coefs"`` entry indexable by group name (presumably a DataFrame
        with one column per group -- TODO confirm against the fitting code).
    group : str
        Group label to extract (e.g. ``"ror"`` or ``"low_rt"``).

    Returns
    -------
    pandas.DataFrame
        One column per basin that contains ``group``. Basins without that
        group are skipped, and the frame is empty if no basin has it.

    Raises
    ------
    KeyError
        If a basin's results have no ``"coefs"`` entry. Only a missing
        *group* is expected and skipped; a malformed results object is not.
    """
    coefs = {}
    for basin, results in simp_results.items():
        basin_coefs = results["coefs"]
        # Not every basin fits every group, so skip when the group is absent.
        # The old try/except KeyError also hid a missing "coefs" entry.
        if group in basin_coefs:
            coefs[basin] = basin_coefs[group]
    return pd.DataFrame(coefs)

In [51]:
# Simple-model coefficients for the run-of-river group, one row per basin.
get_group_coefs(simp_results, "ror").transpose().round(decimals=3)

Unnamed: 0,const,inflow,storage_pre,release_pre,storage_roll7,inflow_roll7,release_roll7,storage_x_inflow
tva,0.0,1.38,0.143,0.019,-0.024,0.001,0.01,-0.445
missouri,0.0,0.871,0.037,0.107,-0.033,-0.098,0.083,0.038


In [50]:
# Simple-model coefficients for the low_rt group, one row per basin.
get_group_coefs(simp_results, "low_rt").transpose().round(decimals=3)

Unnamed: 0,const,inflow,storage_pre,release_pre,storage_roll7,inflow_roll7,release_roll7,storage_x_inflow
upper_col,-0.0,0.36,0.062,0.802,-0.053,-0.412,0.186,0.063
tva,0.0,0.929,0.172,0.326,-0.099,0.002,0.048,-0.319
pnw,-0.0,0.901,0.019,0.084,-0.004,0.002,0.0,0.015
