# In [None]:
# -*- coding: utf-8 -*-
"""
#-------------------------------------------------------------------------
#  Purpose : Read training binary set 
#  Author  : KIM MK
#  Content : 
#     1. binary file load
#     2. preprocessing
#     3. forecast
#     5. plot
#  History : 
#       Code by Aug. 13, 2018 ManKi Kim
#          - import read fortran subroutine
#        Add by Aug. 23, 2018 ManKi Kim
#          -           
#-------------------------------------------------------------------------
"""

#-------------------------------------------------------------------------
# .. Module load

#.. module
import os
import numpy as np
import pandas as pd
from tensorflow.keras.models import load_model, model_from_json
from tensorflow.keras import backend as K
import joblib
from tcn import TCN
import copy

#.. local
import sys
sys.path.insert(0, './inc')
from test_data_load import test_data_load
from test_find_stnidx import find_stn_idx
from check_missing_existence import check_missing_existence

#-------------------------------------------------------------------------
# .. Custom function

def sort_use_stnid(str):
    """Return the integer encoded in the last '_'-separated field of a name.

    The text after the final underscore is cut at its first '.', and what
    remains is converted with ``int``.  Used as the ``key`` when sorting the
    model and scaler file names.

    Parameters
    ----------
    str : file name (the parameter name shadows the builtin; it is kept so
          existing callers are unaffected).

    Raises
    ------
    ValueError : if the extracted field is not a valid integer literal.
    """
    # Everything after the final '_' (the whole name when there is none).
    last_field = str.rsplit('_', 1)[-1]
    # Keep only the part in front of the first '.', e.g. '47108.h5' -> '47108'.
    id_text, _, _ = last_field.partition('.')
    return int(id_text)



def model_list(load_dir, tran_peri, each_stn_mod):
    """Return the model file names in *load_dir* selected by the run tokens.

    A file name is selected when it contains every one of the module-level
    tokens ``find_ep``, ``find_pd``, ``find_ks``, ``find_ns``, ``find_dl``,
    ``find_lr`` and ``find_nf``, the first training period ``tran_peri[0]``
    and the text ``'h5'``.  When ``each_stn_mod`` is ``"ON"`` the name must
    also contain ``find_id``.

    Parameters
    ----------
    load_dir     : directory that is listed with ``os.listdir``.
    tran_peri    : sequence of period strings; only element 0 is used.
    each_stn_mod : ``"ON"`` adds the ``find_id`` filter, any other value
                   leaves it out.

    Returns
    -------
    list of file names, sorted with ``sort_use_stnid`` as key.

    Raises
    ------
    OSError    : whatever ``os.listdir`` raises for *load_dir*.
    ValueError : from ``sort_use_stnid`` when a selected name does not end
                 in an integer field.
    """
    # Tokens every selected file name has to contain.
    required = [find_ep, tran_peri[0], 'h5', find_pd, find_ks,
                find_ns, find_dl, find_lr, find_nf]
    if each_stn_mod == "ON":
        required.append(find_id)

    # Substring membership instead of ``find(...) > 0``: the latter silently
    # rejects a token that matches at the very start of the name (index 0).
    selected = [name for name in os.listdir(load_dir)
                if all(token in name for token in required)]

    selected.sort(key=sort_use_stnid)
    return selected



def scaler_list(load_dir, tran_peri, each_stn_mod):
    """Return the NWP and observation scaler file names found in *load_dir*.

    A file name is a candidate when it contains every one of the module-level
    tokens ``find_ep``, ``find_pd``, ``find_ks``, ``find_ns``, ``find_dl``,
    ``find_lr`` and ``find_nf`` and the first training period
    ``tran_peri[0]``.  When ``each_stn_mod`` is ``"ON"`` the name must also
    contain ``find_id``.  Candidates containing ``"nwp"`` go to the first
    result, candidates containing ``"obs"`` to the second; a name containing
    both appears in both.

    Parameters
    ----------
    load_dir     : directory that is listed with ``os.listdir``.
    tran_peri    : sequence of period strings; only element 0 is used.
    each_stn_mod : ``"ON"`` adds the ``find_id`` filter, any other value
                   leaves it out.

    Returns
    -------
    (nwp_scaler_list, obs_scaler_list) : two ``numpy`` arrays of file names,
    each sorted with ``sort_use_stnid`` as key.

    Raises
    ------
    OSError    : whatever ``os.listdir`` raises for *load_dir*.
    ValueError : from ``sort_use_stnid`` when a selected name does not end
                 in an integer field.
    """
    # Tokens every candidate file name has to contain.
    required = [find_ep, tran_peri[0], find_pd, find_ks,
                find_ns, find_dl, find_lr, find_nf]
    if each_stn_mod == "ON":
        required.append(find_id)

    # Substring membership instead of ``find(...) > 0``: the latter silently
    # rejects a token that matches at the very start of the name (index 0),
    # and was inconsistent with the ``>= 0`` test used for "obs"/"nwp".
    candidates = [name for name in os.listdir(load_dir)
                  if all(token in name for token in required)]

    obs_names = sorted((name for name in candidates if "obs" in name),
                       key=sort_use_stnid)
    nwp_names = sorted((name for name in candidates if "nwp" in name),
                       key=sort_use_stnid)

    return np.array(nwp_names), np.array(obs_names)

#-------------------------------------------------------------------------
# .. Device configuration



#-------------------------------------------------------------------------
# .. Data set

# Tag embedded in the experiment name and in the training/test period strings.
utc = '00'
#exp_name = "SPD_1ST_TEST_NOOPT"
exp_name = "OP_" + utc + "UTC"
home = '/home/mankicom/STD_POOL/SHRT_GDPS/HOURLY1/MODL_DVLP/SPD/TEST'
#home = '/h3/home/nwpr/mankicom/STD_POOL/SHRT_GDPS/HOURLY1/MODL_DVLP/SPD/TEST'
load_dir = home + '/DAOU'
# "ON": model/scaler file selection additionally requires find_id in the
# file name, and find_id is appended to the output file names.
each_stn_mod = "OFF"



## .. var7 47108 best
stn_id = 47108
#stn_id = 47169
#stn_id = 47119
#stn_id = 47165
# The find_* values below are substrings that selected model/scaler file
# names must contain.
find_id = str(stn_id)
find_var = 'var10'
#find_pd = 'pdcausal'    # padding
find_pd = 'pdsame'    # padding
find_ks = 'ks6'      # kernel size
find_ns = 'ns1'       # nb_stacks
find_nf = 'nf87'       # nb_filters
find_dl = 'dl136'       # last dilation
find_lr = 'lr0.009'
find_ep = 'e1000'


data_dir = '../DAIN/'
#data_dir = '../DAIN_3HR_EXP8/'           # for 47119, 47165
prt_outdir = '../DAIO/' + exp_name + "/"
# makedirs (unlike mkdir) also creates a missing parent directory, and
# exist_ok=True makes the call a no-op when the directory is already there.
os.makedirs(prt_outdir, exist_ok=True)


# Sub-directories of load_dir holding the model and scaler files.
mdl_dir = "/MODL/" + exp_name + "/"
scl_dir = "/SCAL/" + exp_name + "/"
#mdl_dir = "/MODL/"
#scl_dir = "/SCAL/"
#mdl_dir = "/MODL/TCNM_IMPV_EXP8_" + utc + "UTC/"
#scl_dir = "/SCAL/TCNM_IMPV_EXP8_" + utc + "UTC/"


#.. exp
# Training period: the period string and the name suffix are joined pairwise.
tran_name = ['-24-1605-2104']
tran_peri = ['20160501{0}-20210430{0}'.format(utc)]
tran_peri_name = [peri + name for peri, name in zip(tran_peri, tran_name)]


# Test periods: one entry per month; test_name, test_peri and num_his are
# parallel lists and must stay the same length.
test_name = ['2105', '2106', '2107', '2108', '2109', '2110', '2111', '2112']
test_peri = ['20210501{0}-20210531{0}-24-'.format(utc),
             '20210601{0}-20210630{0}-24-'.format(utc),
             '20210701{0}-20210731{0}-24-'.format(utc),
             '20210801{0}-20210831{0}-24-'.format(utc),
             '20210901{0}-20210930{0}-24-'.format(utc),
             '20211001{0}-20211031{0}-24-'.format(utc),
             '20211101{0}-20211130{0}-24-'.format(utc),
             '20211201{0}-20211231{0}-24-'.format(utc)]
test_peri_name = [peri + name for peri, name in zip(test_peri, test_name)]
total_test_peri = '2105{0}-2112{0}'.format(utc)
# Count passed to test_data_load for each test period (matches the number
# of days in the corresponding month).
num_his = [31, 30, 31, 31, 30, 31, 30, 31]


element = 'ALLV'

num_fct = 136
num_ele = 10

# Size of the last axis of the input and target/prediction arrays.
input_size = num_ele
output_size = 1

print('# of validation month', len(test_name))

# Station list file, chosen by the per-station switch.
#   alternatives for the non-"ON" case:
#   '../DABA/stn_expand_test_2019081200.dat'
#   '../DABA/stn_opimprv_test_2019081200.dat'
obs_stn_list = ('../DABA/stn_47108.dat'
                if each_stn_mod == "ON"
                else '../DABA/new_dfs_merg_station_directory_2021050100.dat')


#-------------------------------------------------------------------------
# .. Model & Scaler list load

print(load_dir+mdl_dir)
# NOTE: from here on the names model_list / scaler_list results are data;
# model_list in particular is rebound from the function to its result.
model_list = model_list(load_dir+mdl_dir, tran_peri, each_stn_mod)
nwp_scaler_list, obs_scaler_list = scaler_list(load_dir+scl_dir, tran_peri, each_stn_mod)

# Models and scalers are paired by position below and in the forecast loop,
# so fail early with a readable message instead of a later IndexError.
_n_models = len(model_list)
if _n_models == 0:
    raise RuntimeError(
        "No model file in {} matches the selection tokens".format(load_dir+mdl_dir))
if len(nwp_scaler_list) < _n_models or len(obs_scaler_list) < _n_models:
    raise RuntimeError(
        "Found {} model file(s) but only {} nwp / {} obs scaler file(s) in {}".format(
            _n_models, len(nwp_scaler_list), len(obs_scaler_list), load_dir+scl_dir))
if len(nwp_scaler_list) != _n_models or len(obs_scaler_list) != _n_models:
    # Surplus scalers do not stop the run, but positional pairing may be off.
    print("Warning: {} model(s), {} nwp scaler(s), {} obs scaler(s) - check pairing".format(
        _n_models, len(nwp_scaler_list), len(obs_scaler_list)))

print("------ Check model, scaler set list")
for _mdl_file, _nwp_file, _obs_file in zip(model_list, nwp_scaler_list, obs_scaler_list):
    print  (_mdl_file, _nwp_file, _obs_file)
print("model_list type: ", type(model_list))

# Station id of each model: field 13 of the '_'-separated file name, cut at
# its first '.'.  A comprehension is used so the configuration variable
# stn_id is no longer overwritten by the loop.
# NOTE(review): sort_use_stnid reads the LAST field instead - confirm both
# refer to the same field for these file names.
mod_stn_id = [int(_mdl_file.split('_')[13].split('.')[0]) for _mdl_file in model_list]

print(type(mod_stn_id[0]))

#-------------------------------------------------------------------------
# .. Read stn list

# .. all station
# Result containers, pre-filled with NaN so that entries which are never
# written (skipped samples, models that failed to load) stay NaN.
# Leading axes: (test period, sample slot [31 per period], forecast step,
# model); the last axis is the input features or the single output.
_lead_shape = (len(num_his), 31, num_fct, len(model_list))

test_x_all = np.full(_lead_shape + (input_size,), np.nan, dtype=np.float32)
test_y_all = np.full(_lead_shape + (output_size,), np.nan, dtype=np.float32)
pred_test_all = np.full(_lead_shape + (output_size,), np.nan, dtype=np.float32)


#for i in range(5):    # for models
for i in range(len(model_list)):    # for models
#for i in range(dev_stn_id.shape[0]):    # for some stations
#for i in [100, 200, 300 ]:    # for some stations

    #-------------------------------------------------------------------------
    # .. Run forecast for one model file (loaded as a TCN model below)

    # NOTE(review): with exactly one model the whole id list (not the single
    # int) is passed on to test_data_load - confirm that is what it expects.
    if len(model_list) != 1:
        run_stn_id = mod_stn_id[i]
    else:
        run_stn_id = mod_stn_id

    print ( "=================== Start {} station".format(run_stn_id) )

    # .. clear Keras backend state before loading the next model
    K.clear_session()

    # .. Load model
    model_name = load_dir + mdl_dir + model_list[i]
    print ("load_model: ", model_list[i])
    try:
        model = load_model(model_name, custom_objects={'TCN':TCN} )
    except Exception as err:
        # Best effort: report why this model could not be loaded and move on;
        # its slice of the result arrays stays NaN.  (A bare ``except`` would
        # also swallow KeyboardInterrupt / SystemExit.)
        print ("Error: Could not load ", model_name, "-", err)
        continue

    # .. Load scalers paired with this model by position
    print ("load_scaler:", nwp_scaler_list[i])
    print ("load_scaler:", obs_scaler_list[i])
    nwp_scaler = joblib.load(load_dir + scl_dir + nwp_scaler_list[i])
    obs_scaler = joblib.load(load_dir + scl_dir + obs_scaler_list[i])

    for k in range(len(test_name)):  # for test periods (one per month)
    #for k in [4]:  # single test period

        #-------------------------------------------------------------------------
        # .. Data load

        print(test_peri_name[k])

        test_x, test_y = test_data_load(data_dir, test_peri_name[k], element,
                                        input_size, output_size, num_his[k],
                                        num_fct, run_stn_id)

        # Swap the first two axes so that axis 0 is the one iterated below.
        test_x = np.swapaxes(test_x,0,1)
        #test_x_ori = copy.deepcopy(test_x)
        test_y = np.swapaxes(test_y,0,1)

        print ("======= Loaded data shape ")
        print ("test_x shape: ", test_x.shape)
        print ("test_y shape: ", test_y.shape)

        # .. b samples, each with s steps of f features
        b, s, f = test_x.shape

        #---------------------------------------------------------------------
        # .. Model run

        # .. 2021.02.25 kmk
        for j in range(b):

            # .. check missing: skip the sample when the helper reports a
            #    non-zero count
            nwp_count = check_missing_existence(test_x[j,:,:], test_y[j,:,:])
            if nwp_count > 0:
                print ( "------- nwp missing count > 0, pass ", j, '  day' )
                continue

            # .. normalize: the scaler takes 2-D input (N, n_feature); the
            #    result is reshaped to a batch of one sequence for the model
            nor_test_x = nwp_scaler.transform(test_x[j].reshape(s, f))
            nor_test_x = nor_test_x.reshape(1, s, f)

            # .. predict, then map the prediction back through the obs scaler
            nor_pred_test = model.predict(nor_test_x, batch_size=1)
            inv_pred_test = obs_scaler.inverse_transform(nor_pred_test.reshape(s, output_size))
            inv_pred_test = inv_pred_test.reshape(s, output_size)

            # .. Data save (period k, sample j, model i; full last axis)
            test_x_all[k,j,:,i,:] = test_x[j,:,:]
            test_y_all[k,j,:,i,:] = test_y[j,:,:]
            pred_test_all[k,j,:,i,:] = inv_pred_test[:,:]


#pred_test_all = pred_test_all[:,:,2:-2,:]
#test_x_all = test_x_all[:,:,2:-2,:]
#test_y_all = test_y_all[:,:,2:-2,:]
print ("-------------- all list fcst complete")
for _label, _array in (("pred_test_all: ", pred_test_all),
                       ("test_x_all: ", test_x_all),
                       ("test_y_all: ", test_y_all)):
    print (_label, _array.shape)

#print ( pred_test_all )

# Common tail of the three output file names; the station id is appended
# only when the per-station switch is "ON".
_out_tail = total_test_peri + '_' + find_ep
if each_stn_mod == "ON":
    _out_tail = _out_tail + '_' + find_id

# One archive per array, each holding the array as 'value' and the model
# station ids as 'stn_id':
#   tcnm = predictions, g128 = model inputs, tobs = targets.
for _prefix, _array in (('tcnm', pred_test_all),
                        ('g128', test_x_all),
                        ('tobs', test_y_all)):
    np.savez( (prt_outdir + _prefix + '_' + _out_tail), value=_array, stn_id=mod_stn_id )


