In [1]:
import os

import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from scipy import signal
import xesmf as xe
from keras import backend as keras_backend
from keras.models import load_model

from mlprecip_utils import *
from mlprecip_models import *
from mlprecip_xai import *
from mlprecip_plot import *

import warnings
warnings.filterwarnings("ignore")

Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# Target variable and the ensemble of saved models to validate
varname = 'precip'
model_types = ['nnmodel_med']
nmodels = 100                            # model files are numbered 0 .. nmodels-1
model_path = '../data/fnn/seus.median'   # filename prefix of the saved .h5 models

# Season definitions as lists of calendar months
winter, summer = [12, 1, 2], [6, 7, 8]

# Labels for the two probability columns of the model output
cat_labels = ['Negative', 'Positive']

### Read in train-validation and test data

In [3]:
# Features and targets for the train/validation split
ds_features_tval = xr.open_dataset('../data/fnn/features_trainval.nc')
ds_target_tval = xr.open_dataset('../data/fnn/target_trainval.nc')

# Features and targets for the test split
ds_features_test = xr.open_dataset('../data/fnn/features_test.nc')
ds_target_test = xr.open_dataset('../data/fnn/target_test.nc')

In [4]:
# Group the splits in a fixed order (train/validation first, then test);
# the validation loop zips these with the labels ['Train-Val', 'Test'].
ds_features = [ds_features_tval, ds_features_test]
ds_target = [ds_target_tval, ds_target_test]

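### Validate each trained model on the train-validation and test sets

For every model type, season (DJF, JJA), split and model index, this writes one netCDF file containing the model's accuracy, predicted classes, class probabilities, verification classes and LRP relevance.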

In [5]:
# Loop-invariant settings, hoisted out of the model loop.
# LRP rule(s) passed to calcLRP; also used as the 'rules' coordinate of the output.
rules = ['lrp.alpha_1_beta_0']
# Filename prefix of the per-model validation files written below.
out_path = '../data/fnn_test/model_validate'
# Create the output directory up front so to_netcdf does not fail on a fresh checkout.
os.makedirs(os.path.dirname(out_path), exist_ok=True)

for model_type in model_types:

    # Report the model type being processed (not the whole list of types).
    print(model_type)

    # Loop over seasons
    for seas, slabel, seas_abbrv in zip([winter, summer], ['Winter', 'Summer'], ['DJF', 'JJA']):

        print(slabel)

        # Loop over Train-Val and Test
        for ds_f, ds_t, label in zip(ds_features, ds_target, ['Train-Val', 'Test']):

            print(label)

            # Make sure we have the same times for target and features: drop target
            # times with missing values and fill missing feature values with 0.
            ds_t, ds_f = xr.align(ds_t.dropna(dim='time'),
                                  ds_f.fillna(0.0),
                                  join='inner')

            # Select season from target precip anomalies and features
            ds_f_seas = ds_f.sel(time=ds_f['time.month'].isin(seas))
            ds_t_seas = ds_t.sel(time=ds_t['time.month'].isin(seas))

            # Subtract the median to ensure data is centered and classes are equal.
            # NOTE(review): the median is recomputed for each split, so the Test
            # threshold is derived from Test data itself. Confirm this matches how the
            # models were trained; otherwise reuse the Train-Val median for Test.
            median = np.percentile(ds_t_seas[varname], 50)
            ds_t_seas[varname] = ds_t_seas[varname] - median

            # Create X features input with dims (time, features). feature_vars labels
            # the LRP 'var' dimension; this assumes X's columns follow data_vars order
            # one-to-one (true when every feature variable has only a 'time' dimension).
            feature_vars = list(ds_f_seas.data_vars)
            X = ds_f_seas.to_stacked_array('features', sample_dims=['time']).values

            # One Hot Encode Target (Y)
            Y_ohe = make_ohe_thresh_med(ds_t_seas[varname])

            print('Check Features and Target Dimensions')
            print('Features (X): ', X.shape)
            print('Target (Y): ', Y_ohe.shape)

            # Quantities shared by every model for this season/split: compute once.
            times = ds_f_seas['time']
            ds_verif = xr.DataArray(np.argmax(Y_ohe, axis=1),
                                    coords={'time': times},
                                    dims=['time']).to_dataset(name='verif')

            # Loop over Models
            for imodel in range(nmodels):

                model_infname = '.'.join([model_path, seas_abbrv, model_type, str(imodel), 'h5'])
                print(model_infname)
                model = load_model(model_infname)

                # Accuracy Score (the metric at index 1 of metrics_names)
                score = model.evaluate(X, Y_ohe)
                print("%s: %.2f%%" % (model.metrics_names[1], score[1] * 100))

                # Predictions: class probabilities and the most likely class
                Yprobs = model.predict(X)
                Ypred = np.argmax(Yprobs, axis=1)

                # Layer-wise relevance; X is already 2-D, so no reshape is needed.
                b = np.asarray(calcLRP(model, X, rules=rules))

                # Put all model output information into a Dataset to be written to a netcdf file
                ds_lrp = xr.DataArray(b,
                                      coords={'rules': rules,
                                              'time': times,
                                              'var': feature_vars},
                                      dims=['rules', 'time', 'var']).to_dataset(name='lrp')
                del b

                ds_pred = xr.DataArray(Ypred, coords={'time': times},
                                       dims=['time']).to_dataset(name='pred')

                ds_probs = xr.DataArray(Yprobs, coords={'time': times,
                                                        'cat': cat_labels},
                                        dims=['time', 'cat']).to_dataset(name='probs')
                ds_acc = xr.DataArray(score[1],
                                      coords={'model': [imodel]},
                                      dims=['model']).to_dataset(name='acc')

                ds = xr.merge([ds_lrp, ds_pred, ds_verif, ds_probs, ds_acc])

                del ds_lrp, ds_pred, ds_probs, ds_acc

                model_ofname = '.'.join([out_path, label, seas_abbrv, model_type, str(imodel), 'nc'])
                ds.to_netcdf(model_ofname)

                # Standalone Keras on the TensorFlow backend keeps every loaded model in
                # one global graph/session; clear it so memory and per-call overhead do
                # not keep growing across the nmodels loads of every season and split.
                keras_backend.clear_session()

['nnmodel_med']
Summer
Test
Upper Cat:  92
Lower Cat:  92
Check Features and Target Dimensions
Features (X):  (184, 10)
Target (Y):  (184, 2)
../data/fnn_test/seus.median.JJA.nnmodel_med.50.h5
acc: 51.09%
../data/fnn_test/seus.median.JJA.nnmodel_med.51.h5
acc: 42.39%
../data/fnn_test/seus.median.JJA.nnmodel_med.52.h5
acc: 51.63%
../data/fnn_test/seus.median.JJA.nnmodel_med.53.h5
acc: 50.00%
../data/fnn_test/seus.median.JJA.nnmodel_med.54.h5
acc: 51.09%
../data/fnn_test/seus.median.JJA.nnmodel_med.55.h5
acc: 51.09%
../data/fnn_test/seus.median.JJA.nnmodel_med.56.h5
acc: 48.91%
../data/fnn_test/seus.median.JJA.nnmodel_med.57.h5
acc: 54.89%
../data/fnn_test/seus.median.JJA.nnmodel_med.58.h5
acc: 42.93%
../data/fnn_test/seus.median.JJA.nnmodel_med.59.h5
acc: 48.37%
../data/fnn_test/seus.median.JJA.nnmodel_med.60.h5
acc: 53.80%
../data/fnn_test/seus.median.JJA.nnmodel_med.61.h5
acc: 50.00%
../data/fnn_test/seus.median.JJA.nnmodel_med.62.h5
acc: 47.28%
../data/fnn_test/seus.median.JJA.nnmode