<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#def_pytorch" data-toc-modified-id="def_pytorch-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>def_pytorch</a></span></li><li><span><a href="#ModelSet" data-toc-modified-id="ModelSet-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>ModelSet</a></span></li><li><span><a href="#Data" data-toc-modified-id="Data-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Data</a></span></li><li><span><a href="#rescale" data-toc-modified-id="rescale-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>rescale</a></span></li><li><span><a href="#single_main" data-toc-modified-id="single_main-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>single_main</a></span></li><li><span><a href="#single_func" data-toc-modified-id="single_func-6"><span class="toc-item-num">6&nbsp;&nbsp;</span>single_func</a></span></li><li><span><a href="#apply_ufunc---Change" data-toc-modified-id="apply_ufunc---Change-7"><span class="toc-item-num">7&nbsp;&nbsp;</span>apply_ufunc - Change</a></span></li><li><span><a href="#model_id" data-toc-modified-id="model_id-8"><span class="toc-item-num">8&nbsp;&nbsp;</span>model_id</a></span></li><li><span><a href="#all_main" data-toc-modified-id="all_main-9"><span class="toc-item-num">9&nbsp;&nbsp;</span>all_main</a></span></li><li><span><a href="#simple-analyze" data-toc-modified-id="simple-analyze-10"><span class="toc-item-num">10&nbsp;&nbsp;</span>simple analyze</a></span></li><li><span><a href="#defective-run" data-toc-modified-id="defective-run-11"><span class="toc-item-num">11&nbsp;&nbsp;</span>defective run</a></span><ul class="toc-item"><li><span><a href="#ModelSet_de" data-toc-modified-id="ModelSet_de-11.1"><span class="toc-item-num">11.1&nbsp;&nbsp;</span>ModelSet_de</a></span></li><li><span><a href="#single_main_de" data-toc-modified-id="single_main_de-11.2"><span class="toc-item-num">11.2&nbsp;&nbsp;</span>single_main_de</a></span></li><li><span><a href="#single_func_de" 
data-toc-modified-id="single_func_de-11.3"><span class="toc-item-num">11.3&nbsp;&nbsp;</span>single_func_de</a></span></li><li><span><a href="#apply_ufunc_de" data-toc-modified-id="apply_ufunc_de-11.4"><span class="toc-item-num">11.4&nbsp;&nbsp;</span>apply_ufunc_de</a></span></li><li><span><a href="#model_id" data-toc-modified-id="model_id-11.5"><span class="toc-item-num">11.5&nbsp;&nbsp;</span>model_id</a></span></li><li><span><a href="#all_main" data-toc-modified-id="all_main-11.6"><span class="toc-item-num">11.6&nbsp;&nbsp;</span>all_main</a></span></li></ul></li></ul></div>

In [None]:
import pandas as pd
import numpy as np
import xarray as xr
import os
import sys
import time
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
from torch.nn import Module, LSTM, Linear
from torch.utils.data import DataLoader, TensorDataset
import hydroeval as he

import warnings
warnings.filterwarnings("ignore")
%matplotlib inline

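Training and inference run on the server's GPU when CUDA is available and `use_cuda` is enabled in the config; otherwise they fall back to the CPU.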

# def_pytorch

In [None]:
class Net(Module):
    """Stacked LSTM followed by a linear read-out applied at every time step.

    ``config`` must provide ``input_size``, ``hidden_size``, ``lstm_layers``,
    ``dropout_rate`` and ``output_size``.
    """

    def __init__(self, config):
        super().__init__()
        # Recurrent part; batch_first=True means inputs are (batch, time, feature).
        recurrent_kwargs = dict(
            input_size=config.input_size,
            hidden_size=config.hidden_size,
            num_layers=config.lstm_layers,
            batch_first=True,
            dropout=config.dropout_rate,
        )
        self.lstm = LSTM(**recurrent_kwargs)
        # Read-out mapping each hidden vector to the output variables.
        self.linear = Linear(config.hidden_size, config.output_size)

    def forward(self, x, hidden=None):
        """Return ``(per-step outputs, final LSTM state)`` for input ``x``.

        ``hidden`` is the optional initial LSTM state; ``None`` starts from zeros.
        """
        sequence_features, next_hidden = self.lstm(x, hidden)
        return self.linear(sequence_features), next_hidden


def train(config, train_and_valid_data):
    """Train a fresh ``Net`` with early stopping and checkpoint the best epoch.

    Args:
        config: settings object. Reads ``batch_size``, ``use_cuda``,
            ``learning_rate``, ``epoch``, ``patience``, ``model_save_path`` and
            ``do_train_visualized``, plus whatever ``Net(config)`` needs.
        train_and_valid_data: sequence ``(train_X, train_Y, valid_X, valid_Y)``
            of NumPy arrays.

    Side effects:
        Whenever the mean validation loss improves, the model's ``state_dict``
        is written to ``config.model_save_path + 'model_<model_id>.pth'``, where
        ``model_id`` is the module-level counter.
    """
    if config.do_train_visualized:
        import visdom
        vis = visdom.Visdom(env='def_pytorch')     # Optional live loss curves

    global model_id
    model_name = 'model' + '_' + str(model_id) + '.pth'  # .pt and .pth are conventional formats

    train_X, train_Y, valid_X, valid_Y = train_and_valid_data
    train_X, train_Y = torch.from_numpy(train_X).float(), torch.from_numpy(train_Y).float()  # Convert to tensors
    train_loader = DataLoader(TensorDataset(train_X, train_Y),
                              batch_size=config.batch_size)  # Generate trainable batch data

    valid_X, valid_Y = torch.from_numpy(valid_X).float(), torch.from_numpy(valid_Y).float()
    valid_loader = DataLoader(TensorDataset(valid_X, valid_Y), batch_size=config.batch_size)

    device = torch.device("cuda:0" if config.use_cuda and torch.cuda.is_available() else "cpu")  # Decide whether to train on CPU or GPU

    # Build the network here: the function previously relied on a global `model`
    # that is never defined in the notebook, so it failed on a fresh kernel.
    model = Net(config).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)   # Define optimizer
    criterion = torch.nn.MSELoss()  # Define loss

    valid_loss_min = float("inf")   # Best validation loss seen so far
    bad_epoch = 0                   # Consecutive epochs without improvement
    global_step = 0

    for epoch in range(config.epoch):
        model.train()  # Switch to training mode
        train_loss_array = []
        for _train_X, _train_Y in train_loader:
            _train_X, _train_Y = _train_X.to(device), _train_Y.to(device)
            optimizer.zero_grad()  # Clear gradients from the previous step
            # Samples are treated as independent, so every batch starts from a
            # zero hidden state (hidden=None).
            pred_Y, _ = model(_train_X, None)

            loss = criterion(pred_Y, _train_Y)  # Calculate loss
            loss.backward()  # Backpropagate loss
            optimizer.step()  # Update parameters with optimizer
            train_loss_array.append(loss.item())
            global_step += 1
            if config.do_train_visualized and global_step % 100 == 0:  # Displayed every 100 steps
                # global_step is always >= 1 here, so the point is always appended.
                vis.line(X=np.array([global_step]), Y=np.array([loss.item()]), win='Train_Loss',
                         update='append', name='Train', opts=dict(showlegend=True))

        # Early stopping: stop when the validation loss has not improved for
        # config.patience consecutive epochs, to limit overfitting.
        model.eval()  # Convert to prediction mode
        valid_loss_array = []
        with torch.no_grad():  # Validation is forward-only; no autograd graph is needed
            for _valid_X, _valid_Y in valid_loader:
                _valid_X, _valid_Y = _valid_X.to(device), _valid_Y.to(device)
                pred_Y, _ = model(_valid_X, None)
                loss = criterion(pred_Y, _valid_Y)
                valid_loss_array.append(loss.item())

        train_loss_cur = np.mean(train_loss_array)
        valid_loss_cur = np.mean(valid_loss_array)

        if config.do_train_visualized:  # The first train_loss_cur is too large and is not displayed in visdom.
            vis.line(X=np.array([epoch]), Y=np.array([train_loss_cur]), win='Epoch_Loss',
                     update='append' if epoch > 0 else None, name='Train', opts=dict(showlegend=True))
            vis.line(X=np.array([epoch]), Y=np.array([valid_loss_cur]), win='Epoch_Loss',
                     update='append' if epoch > 0 else None, name='Eval', opts=dict(showlegend=True))

        if valid_loss_cur < valid_loss_min:    # Initial value is positive infinity
            valid_loss_min = valid_loss_cur
            bad_epoch = 0
            torch.save(model.state_dict(), config.model_save_path + model_name)  # Keep the best epoch on disk
            if valid_loss_cur < 1e-5:          # Good enough; stop immediately
                break
        else:
            bad_epoch += 1
            if bad_epoch >= config.patience:   # No improvement for `patience` epochs in a row
                break


def predict(config, test_X):
    """Load the checkpoint for the current ``model_id`` and run it on ``test_X``.

    Args:
        config: settings object. Reads ``batch_size``, ``use_cuda`` and
            ``model_save_path``, plus whatever ``Net(config)`` needs.
        test_X: NumPy array of input samples.

    Returns:
        NumPy array with the concatenated (still normalized) predictions of all
        batches.
    """
    global model_id
    model_name = 'model' + '_' + str(model_id) + '.pth'

    # Get test data
    test_X = torch.from_numpy(test_X).float()
    test_loader = DataLoader(TensorDataset(test_X), batch_size=config.batch_size)

    # Load model. map_location lets a checkpoint saved on the GPU be restored
    # when prediction runs on the CPU (and vice versa).
    device = torch.device("cuda:0" if config.use_cuda and torch.cuda.is_available() else "cpu")
    model = Net(config).to(device)
    model.load_state_dict(torch.load(config.model_save_path + model_name, map_location=device))

    # Tensor that accumulates the prediction results
    result = torch.Tensor().to(device)

    # Forecasting process
    model.eval()
    hidden_predict = None
    with torch.no_grad():  # Inference only; avoids building an autograd graph
        for _data in test_loader:
            data_X = _data[0].to(device)
            # The hidden state is carried over from one batch to the next.
            # NOTE(review): this requires all batches to have the same size;
            # a smaller final batch would not match the carried state.
            pred_X, hidden_predict = model(data_X, hidden_predict)
            cur_pred = torch.squeeze(pred_X, dim=0)
            result = torch.cat((result, cur_pred), dim=0)

    return result.detach().cpu().numpy()  # Move to the CPU and hand back plain NumPy data
                                          # finally return the numpy data.

# ModelSet

In [None]:
class ModelSet:
    """Default configuration for the per-grid-cell LSTM (all settings are class attributes)."""

    # ---- Column layout -------------------------------------------------
    # Half-open column ranges used with ``iloc`` on the normalized data frame.
    start_feature = 0
    end_feature = 3
    start_label = 3
    end_label = 5
    # Offset (in time steps) between inputs and labels; 0 means the labels of
    # the same time step are predicted.
    delay_day = 0

    # ---- Row ranges ----------------------------------------------------
    # Half-open row range used for training + validation.
    # NOTE(review): the previous comment described this range as 2000-2013
    # training / 2014-2017 verification; confirm that matches 144 rows.
    start_train_and_valid_date_relative = 0
    end_train_and_valid_date_relative = 144

    # Half-open row range on which predictions are produced.
    start_test_date_relative = 0
    end_test_date_relative = 1212

    # ---- Network shape (derived from the column ranges) ----------------
    input_size = end_feature - start_feature
    output_size = end_label - start_label

    # ---- Hyperparameters -----------------------------------------------
    hidden_size = 150
    lstm_layers = 2
    dropout_rate = 0.4
    time_step = 50              # Length of each input window

    # ---- Run switches --------------------------------------------------
    do_train = True             # Run training
    do_test = True              # Run prediction / evaluation

    shuffle_train_data = False
    use_cuda = True

    # ---- Training parameters -------------------------------------------
    train_data_rate = 0.7
    valid_data_rate = 1 - train_data_rate

    batch_size = 120
    learning_rate = 0.001
    epoch = 256                 # Upper bound on the number of epochs
    patience = 5                # Stop after this many epochs without validation improvement
    random_seed = 666           # Seed used for reproducibility

    # ---- Paths and output switches -------------------------------------
    model_save_path = '/group1/longjs/7-river-data/model_save/'

    do_train_visualized = False
    do_pred_save_to_file = False

    # Make sure the checkpoint directory exists (runs once, when the class is defined).
    if not os.path.exists(model_save_path):
        os.makedirs(model_save_path)

# Data

In [None]:
class Data:
    """Holds the five input series of one grid cell and serves normalized windows.

    The series become the columns ``tair, prec, mass, runoff, et`` (in that
    order) of a data frame; the config's feature/label ranges select columns
    from it by position.
    """

    def __init__(self, config, tair, prec, mass, runoff, et):
        self.config = config
        self.tair, self.prec, self.mass = tair, prec, mass
        self.runoff, self.et = runoff, et

        # Row positions relative to the original series.
        self.data_start_train_and_valid_num = config.start_train_and_valid_date_relative
        self.data_end_train_and_valid_num = config.end_train_and_valid_date_relative
        self.data_start_pred_num = config.start_test_date_relative
        self.data_stop_pred_num = config.end_test_date_relative

        # The same positions expressed relative to the extracted slice.
        first_row = self.data_start_train_and_valid_num
        self.train_and_valid_num_end = self.data_end_train_and_valid_num - first_row
        self.train_num_end = int(self.train_and_valid_num_end * config.train_data_rate)
        self.pred_num_start = self.data_start_pred_num - first_row
        self.pred_num_end = self.data_stop_pred_num - first_row

        columns = {'tair': tair, 'prec': prec, 'mass': mass, 'runoff': runoff, 'et': et}
        self.data_raw = pd.DataFrame(columns)

        # Extracted slice: covers the training/validation rows and the prediction rows.
        self.data = self.data_raw[first_row:self.data_stop_pred_num]

        # Column-wise z-score; NaNs (missing input or 0/0) are replaced by 0.
        self.mean = np.nanmean(self.data, axis=0)
        self.std = np.nanstd(self.data, axis=0)
        self.norm_data = ((self.data - self.mean) / self.std).fillna(0)

        # Number of leading prediction rows dropped because they do not fill a
        # whole window; set by get_test_data().
        self.start_num_in_test = 0

    def get_train_and_valid_data(self):
        """Return ``(train_x, valid_x, train_y, valid_y)`` built from overlapping windows."""
        cfg = self.config
        n_rows = self.train_num_end
        width = cfg.time_step

        features = self.norm_data.iloc[:n_rows, cfg.start_feature:cfg.end_feature]
        # Labels are shifted by delay_day rows relative to the features.
        labels = self.norm_data.iloc[cfg.delay_day:cfg.delay_day + n_rows,
                                     cfg.start_label:cfg.end_label]

        # One window per start row; consecutive windows are shifted by one row.
        window_starts = range(n_rows - width)
        windows_x = np.array([features[s:s + width] for s in window_starts])
        windows_y = np.array([labels[s:s + width] for s in window_starts])

        train_x, valid_x, train_y, valid_y = train_test_split(
            windows_x, windows_y,
            test_size=cfg.valid_data_rate,
            random_state=cfg.random_seed,
            shuffle=cfg.shuffle_train_data,
        )
        return train_x, valid_x, train_y, valid_y

    def get_test_data(self):
        """Return non-overlapping feature windows covering the prediction rows.

        Leading rows that do not fill a whole window are skipped; their count is
        stored in ``self.start_num_in_test``.
        """
        cfg = self.config
        features = self.norm_data.iloc[self.pred_num_start:self.pred_num_end,
                                       cfg.start_feature:cfg.end_feature]
        n_rows = features.shape[0]
        window = min(n_rows, cfg.time_step)   # Guard against time_step exceeding the row count
        n_windows, leftover = divmod(n_rows, window)
        self.start_num_in_test = leftover

        # Consecutive windows are staggered by exactly one window length.
        bounds = [(leftover + k * window, leftover + (k + 1) * window) for k in range(n_windows)]
        return np.array([features[lo:hi] for lo, hi in bounds])

# rescale

In [None]:
def rescale_and_evaluate(config: "ModelSet", origin_data: "Data", predict_norm_data: np.ndarray):
    """Undo the normalization of the network output and split it into two series.

    Args:
        config: settings object. Reads ``output_size``, ``start_label``,
            ``end_label`` and ``do_pred_save_to_file``; when saving is enabled
            also ``pred_save_path`` (falling back to ``model_save_path``).
        origin_data: object exposing the per-column ``mean`` and ``std`` arrays
            that were used for normalization.
        predict_norm_data: normalized predictions; reshaped to
            ``(-1, config.output_size)``.

    Returns:
        ``(runoff, et)``: 1-D float arrays holding output columns 0 and 1 in
        original units.
    """
    # Flatten to two dimensions: (time step, predicted variable).
    predict_norm_data_reshape = predict_norm_data.reshape((-1, config.output_size))
    label_columns = slice(config.start_label, config.end_label)
    # Restore the original scale with the saved mean and standard deviation.
    predict_data = predict_norm_data_reshape * origin_data.std[label_columns] + \
                   origin_data.mean[label_columns]

    # Column slicing replaces the former row-by-row np.append loop, which
    # copied the growing arrays on every iteration.
    runoff = np.asarray(predict_data[:, 0], dtype=np.float64)
    et = np.asarray(predict_data[:, 1], dtype=np.float64)

    if config.do_pred_save_to_file:
        # The config classes in this notebook define no `pred_save_path`, so
        # fall back to the checkpoint directory instead of raising AttributeError.
        save_dir = getattr(config, 'pred_save_path', config.model_save_path)
        pd.DataFrame(predict_data).to_csv(save_dir+'predict'+str(model_id)+'.csv') # save csv file

    return runoff,et

# single_main

In [None]:
# model_id values of grid cells flagged as defective; main() appends to this
# list when its retry loop runs out of attempts.
defective_id = []

In [None]:
def nse_value(s, o):
    """Nash-Sutcliffe efficiency of simulation ``s`` against observation ``o`` (via hydroeval)."""
    return he.evaluator(he.nse, s, o)

def main(config, tair , prec , mass , runoff , et):
    """Train and evaluate one grid cell repeatedly, keeping the best attempt.

    Every attempt retrains the model (when ``config.do_train``), predicts,
    rescales and scores the prediction with the Nash-Sutcliffe efficiency
    (NSE). The attempt with the highest overall ET NSE is kept. Attempts stop
    once the kept ET *and* runoff NSE are no longer both below their targets,
    or after ``nse_count_all`` consecutive attempts without improvement.

    NOTE(review): the loop condition combines the two targets with ``and``, so
    reaching either target ends the loop; confirm this is intended.
    NOTE(review): ``config.do_test`` must be True, otherwise there is nothing
    to score and the NSE computation fails.

    Args:
        config: settings object passed on to ``Data``, ``train`` and ``predict``.
        tair, prec, mass, runoff, et: 1-D series of one grid cell.

    Returns:
        ``(predict_runoff, predict_et, nse_all_runoff, nse_train_runoff,
        nse_test_runoff, nse_all_et, nse_train_et, nse_test_et,
        defective_point)``, where ``defective_point`` is 1 if the retry budget
        was exhausted and 0 otherwise.

    Side effects:
        Appends ``model_id`` to the global ``defective_id`` for defective cells.
        Leaves the checkpoint of the *kept* attempt on disk (see below).
    """
    np.random.seed(config.random_seed)  # Set random seeds to ensure reproducibility
    data_gainer = Data(config, tair , prec , mass , runoff , et)

    global defective_id
    global model_id

    nse_target_et = 0.6
    nse_target_runoff = 0.4
    nse_count_all = 10          # Retry budget: consecutive attempts without improvement

    nse_all_finall_et = -100
    nse_all_finall_runoff = -100
    nse_count = 0
    loop_count = 0

    # train() overwrites this file on every attempt, so without extra care the
    # checkpoint left on disk belongs to the LAST attempt, not to the attempt
    # whose predictions are returned. Keep the bytes of the best attempt's
    # checkpoint and restore them after the loop.
    model_path = config.model_save_path + 'model' + '_' + str(model_id) + '.pth'
    best_checkpoint = None
    checkpoint_is_best = True

    while ((nse_all_finall_et < nse_target_et) and (nse_all_finall_runoff < nse_target_runoff)) and (nse_count<nse_count_all):

        if config.do_train:
            train_X, valid_X, train_Y, valid_Y = data_gainer.get_train_and_valid_data()
            train(config, [train_X, train_Y, valid_X, valid_Y])

        if config.do_test:
            test_X = data_gainer.get_test_data()
            pred_result = predict(config, test_X)      # The output here is the unrestored normalized prediction data.
            predict_temp_runoff,predict_temp_et = rescale_and_evaluate(config, data_gainer, pred_result)

        # Prediction index 0 is compared with observation index 12, i.e. the
        # prediction series is taken to start 12 steps into the input series.
        nse_all_temp_runoff = nse_value(predict_temp_runoff[0:192],data_gainer.runoff[12:204])
        nse_train_temp_runoff = nse_value(predict_temp_runoff[0:132],data_gainer.runoff[12:144])
        nse_test_temp_runoff = nse_value(predict_temp_runoff[132:192],data_gainer.runoff[144:204])

        nse_all_temp_et = nse_value(predict_temp_et[0:192],data_gainer.et[12:204])
        nse_train_temp_et = nse_value(predict_temp_et[0:132],data_gainer.et[12:144])
        nse_test_temp_et = nse_value(predict_temp_et[132:192],data_gainer.et[144:204])

        # The first attempt is always kept; later ones only if the overall ET
        # NSE improves (runoff is not part of the selection criterion).
        if loop_count == 0 or nse_all_finall_et < nse_all_temp_et:
            predict_finall_et = predict_temp_et
            predict_finall_runoff = predict_temp_runoff

            nse_all_finall_runoff = nse_all_temp_runoff
            nse_train_finall_runoff = nse_train_temp_runoff
            nse_test_finall_runoff = nse_test_temp_runoff

            nse_all_finall_et = nse_all_temp_et
            nse_train_finall_et = nse_train_temp_et
            nse_test_finall_et = nse_test_temp_et

            nse_count = 0       # Improvement: restart the retry budget

            if config.do_train and os.path.exists(model_path):
                with open(model_path, 'rb') as checkpoint_file:
                    best_checkpoint = checkpoint_file.read()
            checkpoint_is_best = True
        else:
            checkpoint_is_best = False

        loop_count += 1
        nse_count += 1

        print('now loop = {} , Runoff: temp nse_all = {:.3f} , curren nse_all = {:.3f}, nse_train = {:.3f}, nse_test = {:.3f}'
              .format(loop_count , nse_all_temp_runoff[0] , nse_all_finall_runoff[0] , nse_train_finall_runoff[0] , nse_test_finall_runoff[0]))
        print('                   ET: temp nse_all = {:.3f} , curren nse_all = {:.3f}, nse_train = {:.3f}, nse_test = {:.3f}'
              .format(nse_all_temp_et[0] , nse_all_finall_et[0] , nse_train_finall_et[0] , nse_test_finall_et[0]))
        print(' ')

    # Put the kept attempt's checkpoint back if a later, worse attempt overwrote it.
    if best_checkpoint is not None and not checkpoint_is_best:
        with open(model_path, 'wb') as checkpoint_file:
            checkpoint_file.write(best_checkpoint)

    # The loop ended because the retry budget ran out, not because a target was met.
    if nse_count >= nse_count_all:
        defective_point = 1
        defective_id.append(model_id)
    else:
        defective_point = 0

    return predict_finall_runoff,predict_finall_et,nse_all_finall_runoff,nse_train_finall_runoff,nse_test_finall_runoff,nse_all_finall_et,nse_train_finall_et,nse_test_finall_et,defective_point

# single_func

In [None]:
# Flat accumulators for the per-cell prediction series; single_func() appends
# one series per grid cell and grid_func() reshapes them afterwards.
predict_var_et = []
predict_var_runoff = []
def single_func(tair , prec , mass , runoff , et):
    """Process one grid cell: advance the model counter, then train/evaluate it.

    Cells whose first runoff value is NaN are not modelled; they contribute a
    NaN series of length 1200 and NaN scores. All other cells are handed to
    ``main`` with the default ``ModelSet`` configuration.

    Returns:
        ``(nse_all_runoff, nse_train_runoff, nse_test_runoff, nse_all_et,
        nse_train_et, nse_test_et, defective_point)``.
    """
    global model_id
    global predict_var_et
    global predict_var_runoff

    get_model_id()   # Counter +1 (every cell consumes an id, modelled or not)

    if not np.isnan(runoff[0]):
        (predict_finall_runoff, predict_finall_et,
         nse_all_runoff, nse_train_runoff, nse_test_runoff,
         nse_all_et, nse_train_et, nse_test_et,
         defective_point) = main(ModelSet(), tair , prec , mass , runoff , et)
        print('finish run model_{}, runoff nse = {:.3f} , et nse = {:.3f}'.format(model_id,nse_all_runoff[0],nse_all_et[0]))
        print(' ')
        print('{:=^100s}'.format('Next'))
        print(' ')
    else:
        # Placeholder output for cells without data.
        predict_finall_et = np.full(1200,np.nan)
        predict_finall_runoff = np.full(1200,np.nan)

        nse_all_runoff = nse_train_runoff = nse_test_runoff = np.nan
        nse_all_et = nse_train_et = nse_test_et = np.nan
        defective_point = np.nan

    predict_var_runoff = np.append(predict_var_runoff,predict_finall_runoff)
    predict_var_et = np.append(predict_var_et,predict_finall_et)

    scores = (nse_all_runoff, nse_train_runoff, nse_test_runoff,
              nse_all_et, nse_train_et, nse_test_et)
    return scores + (defective_point,)

# apply_ufunc - Change

In [None]:
def grid_func():
    """Run ``single_func`` on every grid cell and write all results to NetCDF.

    Reads the five driver/target datasets, applies ``single_func`` along the
    ``time`` dimension, saves the seven per-cell score maps, and finally
    reshapes the globally accumulated prediction series into
    ``(time, lat, lon)`` datasets that are saved as well.

    Returns:
        ``(pred_runoff, pred_et, nse_all_runoff, nse_train_runoff,
        nse_test_runoff, nse_all_et, nse_train_et, nse_test_et,
        defective_point)``.
    """
    global predict_var_runoff
    global predict_var_et

    # Read grid data
    data_path = '/group1/longjs/7-river-data/'
    period = slice('2000-01-01','2100-12-31')
    tair = xr.open_dataset(data_path+'tair/'+'UKESM1_0_LL_SSP245_tas_1985_2100.nc').sel(time=period)
    prec = xr.open_dataset(data_path+'prcp/'+'UKESM1_0_LL_SSP245_tp_1985_2100.nc').sel(time=period)
    mass = xr.open_dataset(data_path+'mass/'+'massbaltot_RCP45_2000_2100.nc').sel(time=period)
    # NOTE(review): grid_func_de reads 'runoff_full_time_v3.nc'; confirm which version is intended.
    runoff = xr.open_dataset(data_path+'ET-Runoff/'+'runoff_full_time_v4.nc').sel(time=period)
    et = xr.open_dataset(data_path+'ET-Runoff/'+'evaporation_full_time_v3.nc').sel(time=period)

    # Calculation: one call of single_func per grid cell, seven scalar outputs each.
    inputs = (tair['tair'], prec['prcp'], mass['massbaltot'], runoff['runoff'], et['evaporation'])
    (nse_all_runoff, nse_train_runoff, nse_test_runoff,
     nse_all_et, nse_train_et, nse_test_et,
     defective_point) = xr.apply_ufunc(single_func, *inputs,
                                       input_core_dims=[['time'],['time'],['time'],['time'],['time']],
                                       output_core_dims=[[],[],[],[],[],[],[]],
                                       vectorize=True)

    # Name and save the score maps: (array, variable name, file name).
    result_dir = data_path+'result/'
    score_outputs = [
        (nse_all_runoff, 'nse_all', 'ukesm1_245_nse_all_runoff.nc'),
        (nse_train_runoff, 'nse_train', 'ukesm1_245_nse_train_runoff.nc'),
        (nse_test_runoff, 'nse_test', 'ukesm1_245_nse_test_runoff.nc'),
        (nse_all_et, 'nse_all', 'ukesm1_245_nse_all_et.nc'),
        (nse_train_et, 'nse_train', 'ukesm1_245_nse_train_et.nc'),
        (nse_test_et, 'nse_test', 'ukesm1_245_nse_test_et.nc'),
        (defective_point, 'defective_point', 'ukesm1_245_defective_point_et.nc'),
    ]
    for array, variable_name, _ in score_outputs:
        array.name = variable_name
    for array, _, file_name in score_outputs:
        array.to_netcdf(result_dir+file_name)

    def _as_dataset(variable_name, flat_values):
        # The accumulator holds one 1200-step series per cell, cell after cell:
        # view it as (60, 140, 1200) and move the series axis to the front.
        cube = flat_values.reshape((60,140,1200)).transpose(2,0,1)
        return xr.Dataset({variable_name:(['time','lat','lon'],cube)},
                          coords={'lon':np.arange(70.125,105,0.25),
                                  'lat':np.arange(25.125,40,0.25),
                                  'time':pd.date_range('20010101','21001231',freq='M')})

    pred_runoff = _as_dataset('runoff', predict_var_runoff)
    pred_et = _as_dataset('evaporation', predict_var_et)

    pred_runoff.to_netcdf(result_dir+'ukesm1_245_predict_runoff.nc')
    pred_et.to_netcdf(result_dir+'ukesm1_245_predict_et.nc')

    return pred_runoff,pred_et,nse_all_runoff,nse_train_runoff,nse_test_runoff,nse_all_et,nse_train_et,nse_test_et,defective_point

# model_id

In [None]:
# Running counter of processed grid cells; also used to name checkpoint files.
model_id = 0
def get_model_id():
    """Advance the global ``model_id`` counter by one and return the new value."""
    global model_id
    model_id = model_id + 1
    return model_id

# all_main

In [None]:
# First pass over the whole grid: runs single_func on every cell via grid_func()
# and writes the score maps and prediction datasets to NetCDF.
pred_runoff,pred_et,nse_all_runoff,nse_train_runoff,nse_test_runoff,nse_all_et,nse_train_et,nse_test_et,defective_point = grid_func()

# simple analyze

In [None]:
# Map of the training-period NSE for evapotranspiration returned by grid_func().
nse_train_et.plot()

In [None]:
# Outlier fences for the 'nse_all' column: u = Q3 + 1.5*IQR, d = Q1 - 1.5*IQR.
# The upper fence previously dropped the 1.5 factor on Q1 (2.5*Q3 - Q1), which
# is not Q3 + 1.5*IQR and is inconsistent with the lower fence.
# NOTE(review): `df_data` is not defined anywhere in this notebook; it must be
# created before this cell can run.
nse_all_stats = df_data.describe()['nse_all']
q1, q3 = nse_all_stats['25%'], nse_all_stats['75%']
u = 2.5*q3 - 1.5*q1
d = 2.5*q1 - 1.5*q3

In [None]:
# Number of values above the upper fence and below the lower fence. Returned as
# one tuple so both counts are displayed (a cell only shows its last expression).
# NOTE(review): `nse_all` is not defined anywhere in this notebook.
np.sum(nse_all>u), np.sum(nse_all<d)

In [None]:
# Map of the defective flag per grid cell (1 = retry budget exhausted in main(),
# 0 = accepted, NaN = cell skipped by single_func()).
defective_point.plot()

# defective run

In [None]:
import copy
# Independent snapshot of the defective ids collected by the first pass, so the
# list survives later resets of `defective_id`.
defective_id_raw = copy.deepcopy(defective_id)

In [None]:
# defective_id_raw = copy.deepcopy(defective_get)

## ModelSet_de

In [None]:
class ModelSet_de:
    """Alternative configuration used when re-running grid cells flagged as defective.

    Compared with ``ModelSet`` it uses a single feature column (position 1), a
    smaller hidden size (50) and a smaller batch size (24).
    """

    # ---- Column layout (half-open ranges used with ``iloc``) -----------
    start_feature = 1
    end_feature = 2
    start_label = 3
    end_label = 5
    delay_day = 0               # Offset between inputs and labels, in time steps

    # ---- Row ranges ----------------------------------------------------
    start_train_and_valid_date_relative = 0
    end_train_and_valid_date_relative = 144

    start_test_date_relative = 0
    end_test_date_relative = 1212

    # ---- Network shape (derived from the column ranges) ----------------
    input_size = end_feature - start_feature
    output_size = end_label - start_label

    # ---- Hyperparameters -----------------------------------------------
    hidden_size = 50
    lstm_layers = 2
    dropout_rate = 0.4
    time_step = 50              # Length of each input window

    # ---- Run switches --------------------------------------------------
    do_train = True
    do_test = True

    shuffle_train_data = False
    use_cuda = True

    # ---- Training parameters -------------------------------------------
    train_data_rate = 0.7
    valid_data_rate = 1 - train_data_rate

    batch_size = 24
    learning_rate = 0.001
    epoch = 256                 # Upper bound on the number of epochs
    patience = 5                # Stop after this many epochs without validation improvement
    random_seed = 666

    # ---- Paths and output switches -------------------------------------
    model_save_path = '/group1/longjs/7-river-data/model_save/'

    do_train_visualized = False
    do_pred_save_to_file = False

    # Make sure the checkpoint directory exists (runs once, when the class is defined).
    if not os.path.exists(model_save_path):
        os.makedirs(model_save_path)

## single_main_de

In [None]:
def nse_value(s, o):
    """Nash-Sutcliffe efficiency of simulation ``s`` against observation ``o`` (via hydroeval)."""
    return he.evaluator(he.nse, s, o)

def predict_de(config, tair , prec , mass , runoff , et):
    """Score one grid cell with its already-saved checkpoint (no training).

    Returns the same 9-tuple as ``main``; ``defective_point`` is always 0.
    """
    np.random.seed(config.random_seed)
    data_gainer = Data(config, tair , prec , mass , runoff , et)

    pred_result = predict(config, data_gainer.get_test_data())
    predict_finall_runoff,predict_finall_et = rescale_and_evaluate(config, data_gainer, pred_result)

    def _scores(predicted, observed):
        # NSE over the full, training and test windows (in that order).
        # Prediction index 0 is compared with observation index 12.
        full_window = nse_value(predicted[0:192],observed[12:204])
        train_window = nse_value(predicted[0:132],observed[12:144])
        test_window = nse_value(predicted[132:192],observed[144:204])
        return full_window, train_window, test_window

    nse_all_finall_runoff, nse_train_finall_runoff, nse_test_finall_runoff = _scores(
        predict_finall_runoff, data_gainer.runoff)
    nse_all_finall_et, nse_train_finall_et, nse_test_finall_et = _scores(
        predict_finall_et, data_gainer.et)

    # There is only one attempt here, so "temp" and "curren" show the same value.
    runoff_line = 'Runoff: temp nse_all = {:.3f} , curren nse_all = {:.3f}, nse_train = {:.3f}, nse_test = {:.3f}'
    et_line = '    ET: temp nse_all = {:.3f} , curren nse_all = {:.3f}, nse_train = {:.3f}, nse_test = {:.3f}'
    print(runoff_line.format(nse_all_finall_runoff[0] , nse_all_finall_runoff[0] ,
                             nse_train_finall_runoff[0] , nse_test_finall_runoff[0]))
    print(et_line.format(nse_all_finall_et[0] , nse_all_finall_et[0] ,
                         nse_train_finall_et[0] , nse_test_finall_et[0]))
    print(' ')

    defective_point = 0

    return (predict_finall_runoff, predict_finall_et,
            nse_all_finall_runoff, nse_train_finall_runoff, nse_test_finall_runoff,
            nse_all_finall_et, nse_train_finall_et, nse_test_finall_et,
            defective_point)

## single_func_de

In [None]:
import copy

In [None]:
# First re-run cycle: start from the defective ids collected by the first pass.
defective_id_run = copy.deepcopy(defective_id_raw)

In [None]:
# Empty the list so main() records only the cells that are still defective
# after the upcoming re-run.
defective_id = []

In [None]:
# For the second and later re-run cycles only; run exactly once per cycle.
# Promotes the ids that were still defective after the previous cycle to the
# work list and empties the collector again.
# NOTE(review): running this right after the two cells above (e.g. "Run All")
# replaces defective_id_run with an empty list, so nothing would be retrained.
defective_id_run = copy.deepcopy(defective_id)
defective_id = []

In [None]:
# Number of grid cells that will be retrained in this cycle.
len(defective_id_run)

In [None]:
# Flat accumulators for the re-run pass; single_func_de() appends one series
# per grid cell.
predict_var_et_de = []
predict_var_runoff_de = []

import copy
def single_func_de(tair , prec , mass , runoff , et):
    """Process one grid cell during a re-run cycle.

    * Cells listed in ``defective_id_run`` are retrained with ``ModelSet_de``.
    * Other cells whose first runoff value is NaN yield NaN placeholders.
    * All remaining cells are only re-scored from their saved checkpoint with
      the default ``ModelSet``.

    NOTE(review): a cell retrained with ``ModelSet_de`` in an earlier cycle and
    no longer listed would be re-scored with ``ModelSet``, whose network shape
    differs from the saved checkpoint - confirm this cannot happen.

    Returns:
        ``(nse_all_runoff, nse_train_runoff, nse_test_runoff, nse_all_et,
        nse_train_et, nse_test_et, defective_point)``.
    """
    global model_id
    global predict_var_et_de
    global predict_var_runoff_de
    global defective_id_run

    get_model_id()   # Every cell consumes an id so ids line up with the first pass

    if model_id in defective_id_run:
        (predict_finall_runoff, predict_finall_et,
         nse_all_runoff, nse_train_runoff, nse_test_runoff,
         nse_all_et, nse_train_et, nse_test_et,
         defective_point) = main(ModelSet_de(), tair , prec , mass , runoff , et)
        print('Run model_{}, et nse = {:.3f} , runoff nse = {:.3f}'.format(model_id,nse_all_et[0],nse_all_runoff[0]))
        print(' ')
        print('{:=^100s}'.format('Run'))
        print(' ')
    elif np.isnan(runoff[0]):
        # Placeholder output for cells without data.
        predict_finall_et = np.full(1200,np.nan)
        predict_finall_runoff = np.full(1200,np.nan)

        nse_all_et = nse_train_et = nse_test_et = np.nan
        nse_all_runoff = nse_train_runoff = nse_test_runoff = np.nan
        defective_point = np.nan
    else:
        (predict_finall_runoff, predict_finall_et,
         nse_all_runoff, nse_train_runoff, nse_test_runoff,
         nse_all_et, nse_train_et, nse_test_et,
         defective_point) = predict_de(ModelSet(), tair , prec , mass , runoff , et)
        print('Predict model_{}, et nse = {:.3f} , runoff nse = {:.3f}'.format(model_id,nse_all_et[0],nse_all_runoff[0]))
        print(' ')
        print('{:=^100s}'.format('Predict'))
        print(' ')

    predict_var_et_de = np.append(predict_var_et_de,predict_finall_et)
    predict_var_runoff_de = np.append(predict_var_runoff_de,predict_finall_runoff)

    scores = (nse_all_runoff, nse_train_runoff, nse_test_runoff,
              nse_all_et, nse_train_et, nse_test_et)
    return scores + (defective_point,)

## apply_ufunc_de

In [None]:
def grid_func_de():
    """Re-run pass over the grid: apply ``single_func_de`` and save ``*_de`` results.

    Reads the five driver/target datasets, applies ``single_func_de`` along the
    ``time`` dimension, saves the seven per-cell score maps, and reshapes the
    series accumulated by ``single_func_de`` into ``(time, lat, lon)`` datasets
    that are saved as well.

    Returns:
        ``(pred_runoff, pred_et, nse_all_runoff, nse_train_runoff,
        nse_test_runoff, nse_all_et, nse_train_et, nse_test_et,
        defective_point)``.
    """
    # single_func_de() accumulates into the *_de lists. The function previously
    # reshaped predict_var_runoff / predict_var_et, i.e. the first-pass
    # predictions, so the *_de prediction files never contained the re-run.
    global predict_var_runoff_de
    global predict_var_et_de

    data_path = '/group1/longjs/7-river-data/'
    tair = xr.open_dataset(data_path+'tair/'+'UKESM1_0_LL_SSP245_tas_1985_2100.nc').sel(time=slice('2000-01-01','2100-12-31'))
    prec = xr.open_dataset(data_path+'prcp/'+'UKESM1_0_LL_SSP245_tp_1985_2100.nc').sel(time=slice('2000-01-01','2100-12-31'))
    mass = xr.open_dataset(data_path+'mass/'+'massbaltot_RCP45_2000_2100.nc').sel(time=slice('2000-01-01','2100-12-31'))
    # NOTE(review): grid_func reads 'runoff_full_time_v4.nc'; confirm which version is intended.
    runoff = xr.open_dataset(data_path+'ET-Runoff/'+'runoff_full_time_v3.nc').sel(time=slice('2000-01-01','2100-12-31'))
    et = xr.open_dataset(data_path+'ET-Runoff/'+'evaporation_full_time_v3.nc').sel(time=slice('2000-01-01','2100-12-31'))

    # One call of single_func_de per grid cell, seven scalar outputs each.
    nse_all_runoff,nse_train_runoff,nse_test_runoff,nse_all_et,nse_train_et,nse_test_et,defective_point = xr.apply_ufunc(single_func_de,
                          tair['tair'],prec['prcp'],mass['massbaltot'],runoff['runoff'],et['evaporation'],
                          input_core_dims=[['time'],['time'],['time'],['time'],['time']],
                          output_core_dims=[[],[],[],[],[],[],[]],
                          vectorize=True)

    # Variable names stored in the NetCDF files.
    nse_all_runoff.name = 'nse_all'
    nse_train_runoff.name = 'nse_train'
    nse_test_runoff.name = 'nse_test'

    nse_all_et.name = 'nse_all'
    nse_train_et.name = 'nse_train'
    nse_test_et.name = 'nse_test'

    defective_point.name = 'defective_point'

    nse_all_runoff.to_netcdf(data_path+'result/'+'ukesm1_245_nse_all_runoff_de.nc')
    nse_train_runoff.to_netcdf(data_path+'result/'+'ukesm1_245_nse_train_runoff_de.nc')
    nse_test_runoff.to_netcdf(data_path+'result/'+'ukesm1_245_nse_test_runoff_de.nc')

    nse_all_et.to_netcdf(data_path+'result/'+'ukesm1_245_nse_all_et_de.nc')
    nse_train_et.to_netcdf(data_path+'result/'+'ukesm1_245_nse_train_et_de.nc')
    nse_test_et.to_netcdf(data_path+'result/'+'ukesm1_245_nse_test_et_de.nc')

    defective_point.to_netcdf(data_path+'result/'+'ukesm1_245_defective_point_et_de.nc')

    # The accumulators hold one 1200-step series per cell, cell after cell:
    # view them as (60, 140, 1200) and move the series axis to the front.
    pred_runoff = xr.Dataset({'runoff':(['time','lat','lon'],predict_var_runoff_de.reshape((60,140,1200)).transpose(2,0,1))},
                      coords={'lon':np.arange(70.125,105,0.25),
                              'lat':np.arange(25.125,40,0.25),
                              'time':pd.date_range('20010101','21001231',freq='M')})
    pred_et = xr.Dataset({'evaporation':(['time','lat','lon'],predict_var_et_de.reshape((60,140,1200)).transpose(2,0,1))},
                      coords={'lon':np.arange(70.125,105,0.25),
                              'lat':np.arange(25.125,40,0.25),
                              'time':pd.date_range('20010101','21001231',freq='M')})

    pred_runoff.to_netcdf(data_path+'result/'+'ukesm1_245_predict_runoff_de.nc')
    pred_et.to_netcdf(data_path+'result/'+'ukesm1_245_predict_et_de.nc')

    return pred_runoff,pred_et,nse_all_runoff,nse_train_runoff,nse_test_runoff,nse_all_et,nse_train_et,nse_test_et,defective_point

## model_id

In [None]:
# Restart the counter so the ids assigned during the re-run pass line up with
# the ids of the first pass.
model_id = 0
def get_model_id():
    """Advance the global ``model_id`` counter by one and return the new value."""
    global model_id
    model_id = model_id + 1
    return model_id

## all_main

In [None]:
# Re-run pass over the whole grid via grid_func_de(); overwrites the result
# variables of the first pass in the notebook namespace.
pred_runoff,pred_et,nse_all_runoff,nse_train_runoff,nse_test_runoff,nse_all_et,nse_train_et,nse_test_et,defective_point = grid_func_de()

In [None]:
# (cells retrained in this cycle, cells still defective afterwards). Returned as
# one tuple so both counts are displayed (a cell only shows its last expression).
len(defective_id_run), len(defective_id)