In [1]:
import copy
import os
import pickle as pkl
import sys
import warnings

import numpy as np
import pandas as pd
import xarray as xr
from sklearn.metrics import mutual_info_score

import metrics
# Print the installed xarray version into the cell output.
print("XArray version: ", xr.__version__)

XArray version:  0.16.1


In [2]:
##########################################################################################################
# LOAD IN THE DATA
##########################################################################################################
def _load_pickle(path):
    """Return the object unpickled from the file at `path`."""
    # NOTE(review): unpickling can execute arbitrary code -- only load files from a trusted source.
    with open(path, 'rb') as fb:
        return pkl.load(fb)


nwm_results = _load_pickle('./model_output_for_analysis/nwm_chrt_v2_1d_local.p')

# One dictionary per model and training split; each is keyed by forcing product below.
lstm_results_time_split1, mclstm_results_time_split1, sacsma_results_time_split1 = {}, {}, {}
lstm_results_time_split2, mclstm_results_time_split2, sacsma_results_time_split2 = {}, {}, {}

# (file-name prefix, destination dictionary) pairs, in the order the files are read.
_time_split1_targets = [('lstm', lstm_results_time_split1),
                        ('mclstm', mclstm_results_time_split1),
                        ('sacsma', sacsma_results_time_split1)]
_time_split2_targets = [('lstm', lstm_results_time_split2),
                        ('mclstm', mclstm_results_time_split2),
                        ('sacsma', sacsma_results_time_split2)]

for forcing_type in ['nldas', 'daymet']:

    # time_split1 file names carry an "_ens" suffix ...
    for file_prefix, destination in _time_split1_targets:
        destination[forcing_type] = _load_pickle(
            './model_output_for_analysis/{}_time_split1_{}_ens.p'.format(file_prefix, forcing_type))

    # ... time_split2 file names do not.
    for file_prefix, destination in _time_split2_targets:
        destination[forcing_type] = _load_pickle(
            './model_output_for_analysis/{}_time_split2_{}.p'.format(file_prefix, forcing_type))

# Lookup: training split -> short model name -> results.
# The same NWM results object is registered under both splits.
train_split_type_model_set = {
    'time_split1': {
        'nwm': nwm_results,
        'lstm': lstm_results_time_split1,
        'mc': mclstm_results_time_split1,
        'sac': sacsma_results_time_split1,
    },
    'time_split2': {
        'nwm': nwm_results,
        'lstm': lstm_results_time_split2,
        'mc': mclstm_results_time_split2,
        'sac': sacsma_results_time_split2,
    },
}

In [3]:
##########################################################################################################
# USE A CONVERSION BETWEEN MODELS AND DATA
##########################################################################################################
# Factor for converting flow to CFS: depth in mm -> ft, area in km^2 -> ft^2, and a time step
# divided by 3600 * 24 (i.e. per day -> per second).
FT_PER_MM = 0.00328084
FT2_PER_KM2 = 10763910.41671
conversion_factor = FT_PER_MM * FT2_PER_KM2 / 3600 / 24

In [4]:
##########################################################################################################
# Get all the CAMELS attributes.  
##########################################################################################################

# CSV with the CAMELS attributes (including RI information), indexed by gauge_id
dataName = '../data/camels_attributes.csv'
pd_attributes = pd.read_csv(dataName, sep=',', index_col='gauge_id')

# Basin ID as an 8-character string, left-padded with zeros where necessary
basin_id_str = [str(gauge_id).zfill(8) for gauge_id in pd_attributes.index.values]
pd_attributes['basin_id_str'] = basin_id_str

In [5]:
##########################################################################################################
# Loop through all the SACSMA runs and check that the results are good. 
# Get, for each training split, a list of basins that has good calibration results.
#
# A basin is kept for a split when, for BOTH forcing products, it has a SAC-SMA column in that split's
# results whose values sum to more than zero. For time_split2 the basin must also be present in the
# NWM results.
#
# The previous version of this cell accumulated one flag over both splits and removed the basin after
# the split loop had finished, so only the time_split2 list was ever filtered and the time_split1 list
# always stayed equal to the full CAMELS list. Each split is now evaluated on its own.
##########################################################################################################
basin_list_all_camels = list(pd_attributes['basin_id_str'].values)
basin_list_sacsma_good = {}

for train_split_type in ['time_split1', 'time_split2']:

    model_set = train_split_type_model_set[train_split_type]
    good_basins = []

    for basin_0str in basin_list_all_camels:

        keep_basin = True

        for forcing_type in ['nldas', 'daymet']:
            sac_runs = model_set['sac'][forcing_type]
            if basin_0str not in sac_runs.columns:
                # No SAC-SMA run for this basin / forcing product
                keep_basin = False
            elif sac_runs[basin_0str].sum() <= 0:
                # A run exists but its values sum to zero or less
                keep_basin = False

        # NWM results are only required for time_split2
        if train_split_type == 'time_split2' and basin_0str not in model_set['nwm'].keys():
            keep_basin = False

        if keep_basin:
            good_basins.append(basin_0str)

    basin_list_sacsma_good[train_split_type] = good_basins

In [6]:
##########################################################################################################
#-------------------------------------------------------------------------------------------------
# Solve this problem. I think it is the xarray structures...
# isibleDeprecationWarning: Creating an ndarray from ragged nested sequences 
# (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. 
# If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
##########################################################################################################
np.warnings.filterwarnings('ignore', category=np.VisibleDeprecationWarning)

In [7]:
##########################################################################################################
# Module-level list of flow series names: the three simulations ('lstm', 'mc', 'sac') plus 'obs'.
# The functions defined below do not read this variable; they take their flow list from
# spex['flows'] (see get_specifications).
##########################################################################################################
flows = ['lstm', 'mc', 'sac', 'obs']


In [8]:
def get_specifications(tsplt, forcing_type):
    """
    This function is designed to return specific details of the simulation period.
    Any `tsplt` other than 'time_split2' gets the time_split1 settings.
    The NWM is only included for the combination time_split2 + nldas.

    Inputs:
        tsplt (str): Either time_split2 or time_split1
        forcing_type (str): Either nldas or daymet
    Returns
        spex (dictionary) with the keys:
            start_date (pd.Timestamp): The date the simulation period started
            end_date (pd.Timestamp): The date the simulation period ended
            labelz (dictionary): A mapping between short model name and long model name
            models (list): the short model names
            flows (list): the short model names plus "obs" for observed flow
            basin_list (list): all keys except the last one of the LSTM results for this split and forcing
            tsplt (str): the `tsplt` argument, passed through
    """
    is_time_split2 = (tsplt == 'time_split2')

    if is_time_split2:
        period = (pd.Timestamp('1996-10-01'), pd.Timestamp('2014-01-01'))
        lstm_results = lstm_results_time_split2
    else:
        period = (pd.Timestamp('1989-10-01'), pd.Timestamp('1999-09-30'))
        lstm_results = lstm_results_time_split1

    models = ['lstm', 'mc', 'sac']
    labelz = {'lstm':'LSTM', 'mc':'MC-LSTM','sac':'SAC-SMA', 'obs':'Observed'}
    if is_time_split2 and forcing_type == 'nldas':
        # NWM goes first, both in the model list and in the label mapping
        models = ['nwm'] + models
        labelz = {'nwm':'NWM*', **labelz}

    # The last key of the LSTM results is dropped from the basin list
    # NOTE(review): presumably the last key is not a basin ID -- confirm against the pickled results.
    basin_list = list(lstm_results[forcing_type].keys())[:-1]

    return {"start_date": period[0],
            "end_date": period[1],
            "labelz": labelz,
            "models": models,
            "flows": models + ['obs'],
            "basin_list": basin_list,
            "tsplt": tsplt}

In [9]:
def load_forcing_and_identify_events(tsplt, basin_0str, forcing_dir, file_name_map, forcing_type):
    """
    This function loads in the forcing, and also identifies the indices of precipitation "events"

    An event is a day whose precipitation exceeds a threshold taken from the sorted non-zero
    precipitation values (the value at position n//2 + 1, i.e. just above the median; 0 when there are
    fewer than three non-zero values), and for which both the total of the three preceding days and the
    total of the four following days stay below that same threshold. The first three and the last two
    days of the record are never events.

    Inputs:
        tsplt (str): Either time_split2 or time_split1. Any other value keeps the whole forcing record.
        basin_0str (str): The basin ID as a string with a leading zero
        forcing_dir (str): The directory where to find the forcing file (must end with a path separator,
                           it is concatenated directly with the file name)
        file_name_map (dictionary): Maps forcing_type to the product tag used in the forcing file name
        forcing_type (str): either nldas or daymet
    Return:
        forcing (pd.DataFrame): The forcing data for a particular basin, indexed by date
        precip_events (list): Positional (row) indices into `forcing` of the precipitation "events"
    Raises:
        KeyError: if the forcing file has no precipitation column named "prcp(mm/day)" (in any letter case)
    """
    #-------------------------------------------------------------------------------------------------
    # FORCING
    # sep=r'\s+' is the documented equivalent of the deprecated delim_whitespace=True
    forcing = pd.read_csv(f'{forcing_dir}{basin_0str}_lump_{file_name_map[forcing_type]}_forcing_leap.txt',
                          sep=r'\s+', header=3)
    # Positional row ranges that cut the record down to the period of each training split
    if tsplt == 'time_split1':
        forcing = forcing.iloc[3560:7214]
    if tsplt == 'time_split2':
        forcing = forcing.iloc[6118:]
    forcing.index=pd.to_datetime((forcing.Year*10000+forcing.Mnth*100+forcing.Day).apply(str),format='%Y%m%d')
    #-------------------------------------------------------------------------------------------------

    # The precipitation column is spelled "PRCP(mm/day)" for one forcing product and "prcp(mm/day)" for
    # another, so look it up without regard to letter case instead of hard-coding one spelling.
    precip_column = next((column for column in forcing.columns
                          if str(column).lower() == 'prcp(mm/day)'), None)
    if precip_column is None:
        raise KeyError(f'No "prcp(mm/day)" column in the {forcing_type} forcing of basin {basin_0str}')
    precip = np.asarray(forcing[precip_column].values, dtype=float)

    # Threshold: first sorted non-zero value whose rank fraction i/n exceeds 0.5, i.e. index n//2 + 1
    wet_days = np.sort(precip[precip > 0])
    threshold_index = wet_days.shape[0] // 2 + 1
    if threshold_index < wet_days.shape[0]:
        precip_threshold_50 = wet_days[threshold_index]
    else:
        precip_threshold_50 = 0

    # Get indices of precipitation events whose three preceding days and whose four following days
    # each total less than the threshold. nansum treats missing values as zero.
    precip_events=[]
    for i in range(3, precip.shape[0] - 2):
        if precip[i] <= precip_threshold_50:
            continue
        if np.nansum(precip[i-3:i]) >= precip_threshold_50:
            continue
        if np.nansum(precip[i+1:i+5]) >= precip_threshold_50:
            continue
        precip_events.append(i)

    print("Number of precipitation events", len(precip_events))

    return forcing, precip_events

In [10]:
def calculate_mass_balance_over_events(basin_0str, spex, forcing, precip_events):
    """
    This function totals precipitation, observed flow and simulated flow over a window around each
    precipitation event of one basin, and computes the relative mass error of every model.

    The window of an event at row position e is the positional slice [e-3 : e+5], applied to the forcing
    and to every flow series alike.

    Besides its arguments the function depends on module-level state:
        reads:  forcing_type, precip_column_map, pd_attributes, nwm_results,
                lstm_results_time_split1/2, mclstm_results_time_split1/2, sacsma_results_time_split1/2
        writes: total_mass[forcing_type][tsplt][basin_0str] (replaced) and mass_basin_list[tsplt] (appended)
    The module-level forcing_type must therefore match the `forcing` that is passed in.

    Inputs:
        basin_0str (str): The basin ID as a string with a leading zero
        spex (dictionary): Output of get_specifications(); the keys start_date, end_date, tsplt, models
                           and flows are used
        forcing (pd.DataFrame): Forcing of the basin, as returned by load_forcing_and_identify_events()
        precip_events (list): Row positions in `forcing` of the precipitation events
    Returns:
        mass_balance_over_events (pd.DataFrame): One row per event, labelled by the event position, with
            the event date, the window totals of precipitation / observed / simulated flow, the runoff
            ratio (total_obs / total_precip) and per model:
                AME_<model> = |sim - obs| / obs
                PME_<model> = (sim - obs) / obs where sim > obs, else 0
                NME_<model> = (sim - obs) / obs where sim <= obs, else 0
    """
    basin_int = int(basin_0str)
    start_date = spex["start_date"]
    end_date = spex["end_date"]
    tsplt = spex["tsplt"]
    models = spex['models']
    flows = spex['flows']

    mass_balance_over_events = pd.DataFrame(columns=["event",
                                                     "event_date",
                                                     "total_precip", 
                                                     "total_obs", 
                                                     "total_lstm",
                                                     "total_mc",
                                                     "total_sac",
                                                     "runoff_ratio",
                                                     "AME_lstm",
                                                     "PME_lstm",
                                                     "NME_lstm",
                                                     "AME_mc",
                                                     "PME_mc",
                                                     "NME_mc",
                                                     "AME_sac",
                                                     "PME_sac",
                                                     "NME_sac"])

    #-------------------------------------------------------------------------------------------------
    # Reset the total mass to zero for this basin
    total_mass[forcing_type][tsplt][basin_0str] = {event:{flow:0 for flow in flows} for event in precip_events}
    event_totals = total_mass[forcing_type][tsplt][basin_0str]
    #-------------------------------------------------------------------------------------------------

    #-------------------------------------------------------------------------------------------------
    # We need the basin area to convert the NWM flow to a depth
    basin_area = pd_attributes.loc[basin_int, 'area_geospa_fabric']
    #-------------------------------------------------------------------------------------------------

    #-------------------------------------------------------------------------------------------------
    # Result containers of this training split, and the name of the time coordinate in their xarray data
    results_by_split = {
        'time_split1': (lstm_results_time_split1, mclstm_results_time_split1,
                        sacsma_results_time_split1, 'datetime'),
        'time_split2': (lstm_results_time_split2, mclstm_results_time_split2,
                        sacsma_results_time_split2, 'date'),
    }
    lstm_results, mclstm_results, sacsma_results, time_coord = results_by_split[tsplt]

    def _simulation_period_frame(model_results, variable):
        # One variable of this basin's 1D results, cut to the simulation period, as a dated DataFrame
        xrr = model_results[forcing_type][basin_0str]['1D']['xr'][variable].loc[start_date:end_date]
        return pd.DataFrame(data=xrr.values, index=getattr(xrr, time_coord).values)

    #-------------------------------------------------------------------------------------------------
    # Make dictionary with all the flows
    flow_mm = {}
    #-------------------------------------------------------------------------------------------------
    if tsplt == 'time_split2' and forcing_type == 'nldas':
        # Get the NWM data for this basin in an xarray dataset.
        xr_nwm = xr.DataArray(nwm_results[basin_0str]['streamflow'].values,
                 coords=[nwm_results[basin_0str]['streamflow'].index],
                 dims=['datetime'])
        # convert to mm/day:  flow * 3600 sec/hour * 24 hour/day / (area * 1000)
        # NOTE(review): this factor fits a flow in m3/s and an area in km2 -- confirm the units of the
        # NWM streamflow and of area_geospa_fabric.
        flow_mm['nwm'] = xr_nwm.loc[start_date:end_date]*3600*24/(basin_area*1000)
    #-------------------------------------------------------------------------------------------------
    # Standard LSTM 
    flow_mm['lstm'] = _simulation_period_frame(lstm_results, 'QObs(mm/d)_sim')
    #-------------------------------------------------------------------------------------------------
    # Mass-conserving LSTM
    flow_mm['mc'] = _simulation_period_frame(mclstm_results, 'QObs(mm/d)_sim')
    #-------------------------------------------------------------------------------------------------
    # SACSMA
    flow_mm['sac'] = sacsma_results[forcing_type][basin_0str].loc[start_date:end_date]
    #-------------------------------------------------------------------------------------------------
    # OBSERVATIONS (taken from the MC-LSTM results)
    flow_mm['obs'] = _simulation_period_frame(mclstm_results, 'QObs(mm/d)_obs')

    # Plain arrays for positional window slicing. The NWM series is an xarray.DataArray, which has no
    # .iloc accessor, so every flow is sliced through numpy instead.
    flow_values = {flow: np.asarray(flow_mm[flow]) for flow in flows}

    #-------------------------------------------------------------------------------------------------
    #################    DO MASS PER EVENT
    #-------------------------------------------------------------------------------------------------
    mass_basin_list[tsplt].append(basin_0str)

    for event in precip_events:

        mass_balance_over_events.loc[event,'event'] = event

        sevd = event-3# StartEventDate
        eevd = event+5# EndEventDate

        # NOTE(review): the same row positions are used for the forcing and for the flows. In the
        # displayed time_split1 output event 11 is dated 1989.10.11, i.e. forcing row 0 is 1989-09-30,
        # while the flows are cut at start_date 1989-10-01. If the flow series really start on
        # start_date the two windows are shifted by one day -- confirm.
        # Missing values are ignored by nansum, i.e. they count as zero flow.
        for flow in flows:
            event_totals[event][flow] = np.nansum(flow_values[flow][sevd:eevd])

        precip_series = forcing[precip_column_map[forcing_type]]
        total_precip = np.sum(precip_series.values[sevd:eevd])
        total_obs = event_totals[event]['obs']

        ts = pd.to_datetime(str(precip_series.index.values[event]))
        mass_balance_over_events.loc[event,'event_date'] = ts.strftime('%Y.%m.%d')
        mass_balance_over_events.loc[event,'total_precip'] = total_precip
        mass_balance_over_events.loc[event,'total_obs'] = total_obs
        mass_balance_over_events.loc[event,'total_lstm'] = event_totals[event]['lstm']
        mass_balance_over_events.loc[event,'total_mc'] = event_totals[event]['mc']
        mass_balance_over_events.loc[event,'total_sac'] = event_totals[event]['sac']
        mass_balance_over_events.loc[event,'runoff_ratio'] = total_obs / total_precip

        for model in models:

            _obs = event_totals[event]['obs']
            _sim = event_totals[event][model]

            # NOTE(review): _obs is 0 when the window holds no (or only missing) observed flow; the
            # ratios below are then inf/nan.
            mass_balance_over_events.loc[event,f'AME_{model}'] = np.abs(_sim - _obs) / _obs
            if (_sim - _obs) > 0:
                mass_balance_over_events.loc[event,f'PME_{model}'] = (_sim - _obs) / _obs
                mass_balance_over_events.loc[event,f'NME_{model}'] = 0
            else:
                mass_balance_over_events.loc[event,f'NME_{model}'] = (_sim - _obs) / _obs
                mass_balance_over_events.loc[event,f'PME_{model}'] = 0

    return mass_balance_over_events



In [11]:
##########################################################################################################
# IDENTIFY PRECIPITATION EVENTS (threshold derived from the non-zero precipitation of each basin,
# see load_forcing_and_identify_events)
# THEN DO THE MASS BALANCE CALC OVER SOME WINDOW
##########################################################################################################
forcing_products = ['nldas','daymet']

file_name_map = {'nldas':'nldas', 'daymet':'cida'}
precip_column_map = {'nldas':'PRCP(mm/day)', 'daymet':'prcp(mm/day)'}

# Debug switch: when True the loops stop after the first basin of the first forcing product of the
# first training split (this is how the outputs shown below were produced).
# Set to False to process every split, forcing product and basin.
RUN_FIRST_BASIN_ONLY = True

# forcing product -> training split -> basin -> event -> flow -> total mass over the event window
total_mass = {forcing_type:{time_split:{} for time_split in ['time_split1', 'time_split2']} for \
                       forcing_type in forcing_products}

mass_basin_list={}

# NOTE: the loop variable must keep the name `forcing_type` and stay at module level, because
# calculate_mass_balance_over_events reads it as a global.
for tsplt in ['time_split1', 'time_split2']:
    print('tsplt', tsplt)
    for forcing_type in forcing_products:

        print('forcing_type ',forcing_type)

        # NOTE(review): this list is keyed by split only and is reset for every forcing product, so in a
        # full run it ends up holding the basins of the last forcing product only.
        mass_basin_list[tsplt] = []
        forcing_dir = '/home/NearingLab/data/camels_data/basin_dataset_public_v1p2'+\
            '/basin_mean_forcing/{}_all_basins_in_one_directory/'.format(forcing_type)

        spex = get_specifications(tsplt, forcing_type)

        for basin_0str in spex["basin_list"]:

            print(basin_0str)

            forcing, precip_events = load_forcing_and_identify_events(tsplt, 
                                                                      basin_0str, 
                                                                      forcing_dir, 
                                                                      file_name_map, 
                                                                      forcing_type)

            # Only the table of the most recently processed basin is kept in this variable
            mass_balance_over_events = calculate_mass_balance_over_events(basin_0str, 
                                                                          spex, 
                                                                          forcing, 
                                                                          precip_events)

            print("LSTM absolute mass error", np.mean(mass_balance_over_events.loc[:,'AME_lstm']))
            print("MC-LSTM absolute mass error", np.mean(mass_balance_over_events.loc[:,'AME_mc']))
            print("SacSMA absolute mass error", np.mean(mass_balance_over_events.loc[:,'AME_sac']))

            if RUN_FIRST_BASIN_ONLY:
                break

        if RUN_FIRST_BASIN_ONLY:
            break

    if RUN_FIRST_BASIN_ONLY:
        break



tsplt time_split1
forcing_type  nldas
01022500
Number of precipitation events 62
LSTM absolute mass error 0.34070513325352825
MC-LSTM absolute mass error 0.3318684793287708
SacSMA absolute mass error 0.41735557471304946


In [12]:
mass_balance_over_events

Unnamed: 0,event,event_date,total_precip,total_obs,total_lstm,total_mc,total_sac,runoff_ratio,AME_lstm,PME_lstm,NME_lstm,AME_mc,PME_mc,NME_mc,AME_sac,PME_sac,NME_sac
11,11,1989.10.11,23.52,6.47784,11.594068,10.486248,11.78696,0.275418,0.789804,0.789804,0,0.618788,0.618788,0,0.819582,0.819582,0
68,68,1989.12.07,10.17,7.835024,12.579723,13.387175,10.741683,0.770405,0.605576,0.605576,0,0.708632,0.708632,0,0.370983,0.370983,0
77,77,1989.12.16,15.24,6.094832,8.719742,8.201403,8.388856,0.399923,0.430678,0.430678,0,0.345632,0.345632,0,0.376388,0.376388,0
97,97,1990.01.05,2.27,9.833329,12.827887,11.774343,10.516388,4.331863,0.304531,0.304531,0,0.197391,0.197391,0,0.069464,0.069464,0
127,127,1990.02.04,10.52,15.853225,21.448011,18.991274,11.648147,1.506961,0.352912,0.352912,0,0.197944,0.197944,0,0.265251,0,-0.265251
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3413,3413,1999.02.03,37.05,44.895271,34.298752,30.877831,33.15684,1.211748,0.236028,0,-0.236028,0.312225,0,-0.312225,0.261463,0,-0.261463
3423,3423,1999.02.13,6.7,18.438534,19.184235,14.863183,20.114732,2.75202,0.040443,0.040443,0,0.193906,0,-0.193906,0.090907,0.090907,0
3483,3483,1999.04.14,2.87,17.564274,22.607628,21.092661,9.89182,6.119956,0.287137,0.287137,0,0.200884,0.200884,0,0.436822,0,-0.436822
3495,3495,1999.04.26,2.67,12.755852,11.514612,11.169155,7.183783,4.777473,0.097307,0,-0.097307,0.12439,0,-0.12439,0.436824,0,-0.436824
