In [8]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from scipy import signal
import xesmf as xe
from keras.models import load_model
 
from mlprecip_utils import *
from mlprecip_models import *
from mlprecip_xai import *
from mlprecip_plot import *

import warnings
warnings.filterwarnings("ignore")

In [10]:
varname='precip'
model_type='cnn_cat' 
nmodels=1
model_path='../data/cnn_test/seus.median'
pad_length=10
winter=[12,1,2]
summer=[6,7,8]
cat_labels=['Negative','Positive']

### Read in train and test data

In [11]:
ds_features_tval=xr.open_dataset('../data/cnn/features_trainval.nc')
ds_target_tval=xr.open_dataset('../data/cnn/target_trainval.nc')
ds_features_test=xr.open_dataset('../data/cnn/features_test.nc')
ds_target_test=xr.open_dataset('../data/cnn/target_test.nc')

In [12]:
ds_features=[ds_features_tval,ds_features_test]
ds_target=[ds_target_tval,ds_target_test]

### Main Program to Validate Models

In [13]:
# Loop over seasons
#for seas,slabel,seas_abbrv in zip([winter,summer],['Winter','Summer'],['DJF','JJA']):
for seas,slabel,seas_abbrv in zip([summer],['Summer'],['JJA']):

    print(slabel)
    
    # Loop over Train-Val and Test
    for (ds_f, ds_t,label) in zip(ds_features, ds_target,['Train-Val','Test']):
        
        print(label)
        
        # Make sure we have the same times for target and features
        ds_t,ds_f=xr.align(ds_t.dropna(dim='time'),
                           ds_f.dropna(dim='time'),
                           join='inner')
        
        # Select season from target precip anomalies and features
        ds_f_seas=ds_f.sel(time=ds_f['time.month'].isin(seas)) 
        ds_t_seas=ds_t.sel(time=ds_t['time.month'].isin(seas)) 
    
        # Subtract the median to ensure data is centered and classes are equal
        median=np.percentile(ds_t_seas[varname],50)
        ds_t_seas[varname]=ds_t_seas[varname]-median
       
        # Create X Features input
        feature_vars=list(ds_f_seas.keys())
        da_list=[]
        for v in feature_vars:
            da_list.append(ds_f_seas[v])
            X=xr.combine_nested(da_list,concat_dim='var') 
            X=(X.transpose('time','lat','lon','var')).values
            X=xr.where(X!=0,(X-np.nanmean(X,axis=0))/np.nanstd(X,axis=0),0.0)
            X_pad=np.pad(X,((0,0),(0,0),(pad_length,pad_length),(0,0)),'wrap')

        # One Hot Encode Target (Y)
        Y_ohe=make_ohe_thresh_med(ds_t_seas[varname])
    
        print('Check Features and Target Dimensions')
        print('Features (X): ',X_pad.shape)
        print('Target (Y): ',Y_ohe.shape)
        
        # Loop over Models
        for imodel in range(nmodels):
            
            model_infname=model_path+'.'+seas_abbrv+'.'+model_type+'.'+str(imodel)+'.h5'
            print(model_infname)
            model=load_model(model_infname)
            print(model)
            #model.summary()
            
            # Accuracy Score
            score=model.evaluate(X_pad,Y_ohe) 
            print("%s: %.2f%%" % (model.metrics_names[1], score[1]*100))
            #print('Accuracy: ',str(history))
            
            # Predictions
            Yprobs=model.predict(X_pad)
            Ypred=np.argmax(Yprobs, axis = 1)
            
            # Classification Report
            #print(classification_report(np.argmax(Y_ohe,axis=1), Ypred))
            
            # Confusion Matrix
            #cm = confusion_matrix(np.argmax(Y_ohe,axis=1), Ypred)
            #disp = ConfusionMatrixDisplay(confusion_matrix=cm)
            #disp.plot()
            
            # Calculate LRP (TO-DO: put this into and return ds_lrp)
            rules=['lrp.alpha_1_beta_0']
            a=calcLRP(model,X_pad.reshape(X_pad.shape[0],
                                          X_pad.shape[1],
                                          X_pad.shape[2],
                                          X_pad.shape[3]),rules=rules)
            b=np.asarray(a)[:,:,:,pad_length:-pad_length,:]
            del a
        
            # Put all model output information into a Dataset 
            # to be written to a netcdf file (Function?)
            ds_lrp=xr.DataArray(b,
                                coords={'rules':rules,
                                        'time':ds_f_seas['time'],
                                        'lat':ds_f_seas['lat'],
                                        'lon':ds_f_seas['lon'],
                                        'var':list(ds_f_seas.keys())},
                                        dims=['rules','time','lat','lon','var']).to_dataset(name='lrp') 
            del b
        
            ds_pred=xr.DataArray(Ypred,coords={'time':ds_f_seas['time']},
                                 dims=['time']).to_dataset(name='pred')
    
            ds_probs=xr.DataArray(Yprobs,coords={'time':ds_f_seas['time'],
                                                 'cat':cat_labels},
                                        dims=['time','cat']).to_dataset(name='probs')
            ds_acc=xr.DataArray(score[1],
                                coords={'model':[imodel]},
                                dims=['model']).to_dataset(name='acc')    
            ds_verif=xr.DataArray(np.argmax(Y_ohe,axis=1),
                                  coords={'time':ds_f_seas['time']},
                                  dims=['time']).to_dataset(name='verif')
          
            ds=xr.merge([ds_lrp,ds_pred,ds_verif,ds_probs,ds_acc])
            
            del ds_lrp,ds_pred,ds_verif,ds_probs,ds_acc
            
            #model_ofname='../data/cnn_test/model_validate.'+label+'.'+seas_abbrv+'.'+model_type+'.'+str(imodel)+'.nc'
            #ds.to_netcdf(model_ofname)

Summer
Train-Val
Upper Cat:  1748
Lower Cat:  1748
Check Features and Target Dimensions
Features (X):  (3496, 37, 92, 6)
Target (Y):  (3496, 2)
../data/cnn_test/seus.median.JJA.cnn_cat.0.h5
<keras.engine.training.Model object at 0x7fb17a6f66d0>
acc: 68.48%
Test
Upper Cat:  46
Lower Cat:  46
Check Features and Target Dimensions
Features (X):  (92, 37, 92, 6)
Target (Y):  (92, 2)
../data/cnn_test/seus.median.JJA.cnn_cat.0.h5
<keras.engine.training.Model object at 0x7fb16c2dc350>
acc: 50.00%
