In [None]:
# import dependencies
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import inspect
import os
import random

from keras.layers import PReLU
import tensorflow as tf
from tensorflow.keras.layers import Input, Concatenate
from tensorflow.keras.layers import LSTM, Dense, Masking
from keras.models import Model
from keras.callbacks import ModelCheckpoint

from tensorflow.keras.layers import Lambda
import tensorflow.keras.backend as K

import random
from itertools import combinations as cb
# Seed every RNG this notebook draws from: `random` (sanity-check cell selection),
# NumPy's global RNG, and TensorFlow's global RNG (presumably used by Keras for
# weight initialisation — GPU kernels may still be nondeterministic).
SEED = 10
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

### Define user parameters and definitions

In [None]:
# years to train over (2004-2019 inclusive; np.arange excludes the stop value)
training_years = np.arange(2004, 2020)

# directory containing the tiled training data; set the WUS_TILES_DIR
# environment variable to point elsewhere instead of editing this path
data_direc = os.environ.get('WUS_TILES_DIR', '/direcc/jpflug/ML_layers/WUS_tiles/')

# number of days in the data year
# keep this lower than 360
noDays = 330
assert 0 < noDays < 360, f'noDays must be in (0, 360), got {noDays}'

# where to write the trained model outputs; the trailing separator is kept
# because later cells build file paths by string concatenation
modelOutputs_direc = os.path.join(data_direc, 'temp', 'trial8', '')

# number of epochs to train the model over
no_epochs = 100

### General script functions

In [None]:
# SWE melt-correction function
def SWE_adjust(SWE_array, SCF_array):
    """Rescale the melt limb of each SWE column so snow melts out when snow cover does.

    Every column i (axis 1) is processed independently:
      * locate the SWE peak (first occurrence of the column maximum);
      * locate the first day AFTER the snow-cover peak on which snow cover is exactly 0;
      * if that melt-out day is on/after the SWE peak, subtract the SWE value on that
        day from the column (clipping negatives to 0) and rescale the values from the
        peak onward so the peak keeps its original magnitude. Values before the peak
        are left as in SWE_array.

    A column is returned unchanged when snow cover never reaches 0 after its peak,
    when it reaches 0 before the SWE peak, or when the subtraction leaves no
    positive SWE (rescaling would divide by zero).

    Parameters
    ----------
    SWE_array : np.ndarray
        2-D array; columns are processed independently along axis 0.
    SCF_array : np.ndarray
        Array of the same shape as SWE_array (snow cover per cell). May be an
        integer array.

    Returns
    -------
    np.ndarray
        Adjusted copy of SWE_array; the inputs are not modified.
    """
    snowArr_out = SWE_array.copy()
    for i in range(snowArr_out.shape[1]):
        A = SWE_array[:, i].copy()
        A_max = np.max(A)
        A_loc = np.flatnonzero(A == A_max)[0]  # first occurrence of the SWE peak

        # float copy so NaN can be written even when SCF_array has an integer dtype
        B = SCF_array[:, i].astype(float)
        # ignore everything up to and including the (first) snow-cover peak
        B[:np.flatnonzero(B == np.max(B))[0] + 1] = np.nan
        zero_days = np.flatnonzero(B == 0)
        if zero_days.size == 0:
            continue  # snow cover never returns to 0: leave the column unchanged
        idx = zero_days[0]
        if idx < A_loc:
            # snow cover vanishes before SWE peaks: skip THIS column only
            # (columns are independent, so this must not stop the whole loop)
            continue

        # remove the SWE still present on the melt-out day, clipping at zero
        A = A - A[idx]
        A[A < 0] = 0
        remaining_max = np.nanmax(A)  # check this column, not the whole array
        if remaining_max <= 0:
            # nothing left to rescale (A_max / 0): keep the original column
            continue
        # stretch the melt limb so the peak keeps its original magnitude;
        # entries before the peak keep their SWE_array values
        snowArr_out[A_loc:, i] = A[A_loc:] * (A_max / remaining_max)
    return snowArr_out

# walk past no-data (negative) entries without wrapping around the array ends
def _nearest_valid_index(values, start, step):
    """Return the first index from `start` (moving by `step`, +1 or -1) whose value is not negative.

    Negative entries are treated as no-data and skipped. Returns None when the
    walk leaves the array, rather than wrapping to the other end through
    Python's negative indexing.
    """
    pos = start
    n = len(values)
    while 0 <= pos < n:
        # `not (x < 0)` rather than `x >= 0` so a NaN also ends the walk
        if not values[pos] < 0:
            return pos
        pos += step
    return None


# fill no data and sporadic snow presence after snow disappearance
def postProcess_SCA(snowCov_array, nodata=-10000, snow_threshold=0.01,
                    edge_threshold=0.05, report_every=1000):
    """Remove isolated snow detections and mark/fill no-data gaps in each column.

    Each column (axis 1) is scanned along axis 0 in three passes:
      1. An interior value > snow_threshold whose nearest valid (non-negative)
         neighbours on both sides are <= snow_threshold is set to 0.
      2. An interior 0 whose nearest valid neighbours on both sides are
         > snow_threshold is set to `nodata`.
      3. `nodata` values before the first value > edge_threshold, and after the
         last value > edge_threshold, are set to 0. If no value exceeds
         edge_threshold, this pass changes nothing.
    Passes 1 and 2 edit the column while scanning, so later comparisons see
    earlier edits. When a neighbour search reaches either end of the column the
    value is left as is (pass 1 prints a message in that case).

    Parameters
    ----------
    snowCov_array : np.ndarray
        2-D array; negative entries are treated as no-data.
    nodata : number
        Marker written for gaps and replaced at the edges (default -10000,
        the same value the model's Masking layers skip).
    snow_threshold : float
        Snow/no-snow threshold used in passes 1 and 2 (default 0.01).
    edge_threshold : float
        Threshold defining the first/last snow day in pass 3 (default 0.05).
    report_every : int
        Print the column index every `report_every` columns; 0/None disables.

    Returns
    -------
    np.ndarray
        Processed copy; the input array is not modified.
    """
    snowCov_out = snowCov_array.copy()
    n_days = snowCov_array.shape[0]
    for i in range(snowCov_array.shape[1]):
        A = snowCov_array[:, i].copy()

        # pass 1: zero out snow that has no snow on either valid neighbour
        for idx in np.flatnonzero(A > snow_threshold):
            if idx == 0 or idx == n_days - 1:
                continue  # no neighbour on one side
            before = _nearest_valid_index(A, idx - 1, -1)
            after = _nearest_valid_index(A, idx + 1, 1)
            if before is None or after is None:
                print(i, 'running against the array boundary')
                continue
            if A[before] <= snow_threshold and A[after] <= snow_threshold:
                A[idx] = 0

        # pass 2: a zero surrounded by snow is treated as missing data
        for idx in np.flatnonzero(A == 0):
            if idx == 0 or idx == n_days - 1:
                continue
            before = _nearest_valid_index(A, idx - 1, -1)
            after = _nearest_valid_index(A, idx + 1, 1)
            if before is None or after is None:
                continue
            if A[before] > snow_threshold and A[after] > snow_threshold:
                A[idx] = nodata

        # pass 3: no-data outside the snow season becomes zero
        first_positive_index = np.argmax(A > edge_threshold)
        head = A[:first_positive_index]  # view: edits write through to A
        head[head == nodata] = 0
        last_positive_index = len(A) - np.argmax(A[::-1] > edge_threshold) - 1
        tail = A[last_positive_index + 1:]
        tail[tail == nodata] = 0

        snowCov_out[:, i] = A
        if report_every and i % report_every == 0:
            print(i)
    return snowCov_out

# save/store custom loss functions
# for now, keeping only what was used in Pflug et al. (202X)
def default_mean_squared_error(y_true, y_pred):
    """Mean squared error averaged over every element of the inputs.

    Parameters
    ----------
    y_true, y_pred : tensors of matching (or broadcastable) shape.

    Returns
    -------
    Scalar tensor: mean of (y_true - y_pred) ** 2 over all elements.
    """
    residual = y_true - y_pred
    squared_residual = tf.square(residual)
    return tf.reduce_mean(squared_residual)

### Model preparation

In [None]:
def snowLSTM_linear_zeroBound(n_days=None, lstm_units=32, mask_value=-10000.0):
    """Build the four-input LSTM model with two linear per-day output heads.

    Inputs 1 and 2 go through a Masking layer, so timesteps equal to
    `mask_value` are skipped by their LSTMs; inputs 3 and 4 are not masked.
    Each input feeds its own LSTM. The four final LSTM outputs are
    concatenated and mapped by two independent linear Dense heads named
    'output_mse1' and 'output_mse3', each producing `n_days` values.

    NOTE(review): despite the name, no zero bound is applied here — both heads
    are linear. A zero-bounded head (`CustomOutputLayer`) was once sketched in
    this function but that layer is not defined in the visible notebook.

    Parameters
    ----------
    n_days : int, optional
        Sequence length of every input and width of each output head.
        Defaults to the notebook-level `noDays`, read at call time.
    lstm_units : int
        Hidden units in each LSTM (default 32).
    mask_value : float
        Value masked out in inputs 1 and 2 (default -10000.0, the no-data
        marker written by postProcess_SCA).

    Returns
    -------
    Uncompiled keras Model with inputs [input_1, input_2, input_3, input_4]
    and outputs [output_mse1, output_mse3].
    """
    if n_days is None:
        n_days = noDays

    # Create each Input followed by its optional Masking layer, in the same
    # order as before, so auto-generated layer names are unchanged.
    is_masked = (True, True, False, False)
    inputs, branches = [], []
    for masked in is_masked:
        inp = Input(shape=(n_days, 1))
        inputs.append(inp)
        branches.append(Masking(mask_value=mask_value)(inp) if masked else inp)

    # one LSTM per input; each returns only its final output
    lstm_outputs = [LSTM(units=lstm_units)(branch) for branch in branches]
    concatenated = Concatenate()(lstm_outputs)

    # two independent linear heads, one value per day
    output_1 = Dense(units=n_days, activation='linear', name='output_mse1')(concatenated)
    output_3 = Dense(units=n_days, activation='linear', name='output_mse3')(concatenated)

    return Model(inputs=inputs, outputs=[output_1, output_3])

### Load the model training data

In [None]:
# intermediate load of the normalized arrays written by an earlier preprocessing step
# (presumably shaped (noDays, n_cells) — the plot below indexes them as [:, cell]; TODO confirm)
dsSWE = np.load(modelOutputs_direc+'dsSWE_norm.npy')
dsSCF = np.load(modelOutputs_direc+'dsSCF_norm.npy')
SCFaccum = np.load(modelOutputs_direc+'dsSCFaccum_norm.npy')
dsT = np.load(modelOutputs_direc+'dsT_norm.npy')
dsP = np.load(modelOutputs_direc+'dsP_norm.npy')
print(dsSWE.shape)

# load the pre-defined random folds
five_split = np.load(modelOutputs_direc+'FOLDSidxs.npy')
# every way of choosing 4 of the 5 folds (the remaining fold is held out)
combines = list(cb([0, 1, 2, 3, 4], 4))

# sanity check plotting for one random cell
# randrange excludes its upper bound; randint(0, n) could return n -> IndexError
locc = random.randrange(SCFaccum.shape[1])
days = np.arange(noDays)
fg, ax = plt.subplots()
ax.plot(days, dsSWE[:, locc], '-k', label='dsSWE')
ax.scatter(days, dsSCF[:, locc], label='dsSCF')
ax.scatter(days, SCFaccum[:, locc], label='SCFaccum')
ax.plot(days, dsT[:, locc], label='dsT')
ax.plot(days, dsP[:, locc], label='dsP')
ax.set_ylim([0, 1])
ax.set(title=f'Normalized inputs, cell {locc}', xlabel='day index', ylabel='normalized value')
ax.legend()

# make sure models aren't being carried over from previous runs
try:
    del model
except NameError:
    print('no model')

### Train the model

In [None]:
# Leave-one-fold-out training over `combines` (the 5 ways of choosing 4 of 5 folds).
# NOTE(review): training starts at combination 4, so only the last combination is
# trained in this run; set first_fold = 0 to train all five.
first_fold = 4


def _to_sequences(arr):
    """(noDays, n_cells) -> (n_cells, noDays, 1), the layout the model's Input layers declare."""
    return arr.T[..., np.newaxis]


for i in range(first_fold, len(combines)):
    # drop Keras global state left over from a previous fold's model
    K.clear_session()

    # training indices: the four folds in this combination
    train_list = np.concatenate(list(five_split[np.array(combines[i])]))
    # the one remaining fold, split in half
    nontraingp = np.setxor1d(combines[i], [0, 1, 2, 3, 4])
    held_fold = five_split[nontraingp[0]]
    half = len(held_fold) // 2
    excluded_list = held_fold[:half]  # set aside; not used in this cell
    valid_list = held_fold[half:]

    # save the weights with the lowest validation loss for this fold
    filepath_1 = modelOutputs_direc+'linear_bestoutput'+ str(i) + "fold" +'.hdf5'
    checkpoint = ModelCheckpoint(filepath=filepath_1,
                                 monitor='val_loss',
                                 verbose=1,
                                 save_best_only=True,
                                 mode='min')
    callbacks = [checkpoint]

    # pull in the model
    model = snowLSTM_linear_zeroBound()

    # post-process snow cover (sporadic snow, gaps, start/end no-data), then
    # derive the melt-corrected SWE target from it
    Barray = postProcess_SCA(SCFaccum[:, train_list])
    Carray = postProcess_SCA(SCFaccum[:, valid_list])
    Darray_SWE = SWE_adjust(dsSWE[:, train_list], Barray)
    Darray_SWE[Barray == 0] = 0
    Earray_SWE = SWE_adjust(dsSWE[:, valid_list], Carray)
    Earray_SWE[Carray == 0] = 0

    train_inputs = [_to_sequences(dsSCF[:, train_list]), _to_sequences(Barray),
                    _to_sequences(dsT[:, train_list]), _to_sequences(dsP[:, train_list])]
    valid_inputs = [_to_sequences(dsSCF[:, valid_list]), _to_sequences(Carray),
                    _to_sequences(dsT[:, valid_list]), _to_sequences(dsP[:, valid_list])]

    # run the training
    model.compile(optimizer='adam',
                  loss={'output_mse1': default_mean_squared_error, 'output_mse3': default_mean_squared_error},
                  loss_weights={'output_mse1': 1.0, 'output_mse3': 1.0})
    model.fit(train_inputs,
              [dsSWE[:, train_list].T, Darray_SWE.T],
              epochs=no_epochs, batch_size=64,
              validation_data=(valid_inputs, [dsSWE[:, valid_list].T, Earray_SWE.T]),
              callbacks=callbacks)