In [1]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from scipy import signal
import xesmf as xe
from keras.models import load_model
 
from mlprecip_utils import *
from mlprecip_models import *
from mlprecip_xai import *
from mlprecip_plot import *

import warnings
warnings.filterwarnings("ignore")

Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
target_dataset='era5-precip'
#target_dataset='era-interim-precip'

varname='precip'
model_type='logmodel_med'
nmodels=100
model_path='../data/cnn/'+target_dataset+'/models/seus.median'
pad_length=10
winter=[12,1,2]
summer=[6,7,8]
cat_labels=['Negative','Positive']

### Read in train and test data

In [3]:
ds_features_tval=xr.open_dataset('../data/cnn/features_trainval.nc')
ds_target_tval=xr.open_dataset('../data/cnn/'+target_dataset+'/target_trainval.nc')
ds_features_test=xr.open_dataset('../data/cnn/features_test.nc')
ds_target_test=xr.open_dataset('../data/cnn/'+target_dataset+'/target_test.nc')

In [4]:
ds_features=[ds_features_tval,ds_features_test]
ds_target=[ds_target_tval,ds_target_test]

In [5]:
ds_features

[<xarray.Dataset>
 Dimensions:    (time: 55520, lat: 37, lon: 72)
 Coordinates:
     level      float32 ...
   * lon        (lon) int64 0 5 10 15 20 25 30 35 ... 325 330 335 340 345 350 355
   * lat        (lat) int64 -90 -85 -80 -75 -70 -65 -60 ... 60 65 70 75 80 85 90
   * time       (time) datetime64[ns] 1979-01-01 ... 2016-12-31T18:00:00
     dayofyear  (time) int64 ...
 Data variables:
     z500       (time, lat, lon) float64 ...
     z850       (time, lat, lon) float64 ...
     u200       (time, lat, lon) float64 ...
     u850       (time, lat, lon) float64 ...
     olr        (lat, time, lon) float64 ...
     sst        (time, lat, lon) float64 ...,
 <xarray.Dataset>
 Dimensions:    (time: 3892, lat: 37, lon: 72)
 Coordinates:
     level      float32 ...
   * lon        (lon) int64 0 5 10 15 20 25 30 35 ... 325 330 335 340 345 350 355
   * lat        (lat) int64 -90 -85 -80 -75 -70 -65 -60 ... 60 65 70 75 80 85 90
   * time       (time) datetime64[ns] 2017-01-01 ... 2019-08-31T1

### Main Program to Validate Models

In [6]:
# Loop over seasons
#for seas,slabel,seas_abbrv in zip([winter,summer],['Winter','Summer'],['DJF','JJA']):
for seas,slabel,seas_abbrv in zip([winter,summer],['Summer'],['JJA']):

    print(slabel)
    
    # Loop over Train-Val and Test
    for (ds_f, ds_t,label) in zip(ds_features, ds_target,['Train-Val','Test']):

        ds_f_subset=ds_f
        
        
        # Make sure we have the same times for target and features
        ds_t,ds_f=xr.align(ds_t.dropna(dim='time'),
                           ds_f_subset.dropna(dim='time'),
                           join='inner')

        
        # Make sure we have the same times for target and features
        ds_t,ds_f=xr.align(ds_t.dropna(dim='time'),
                           ds_f.dropna(dim='time'),
                           join='inner')
        
        # Select season from target precip anomalies and features
        ds_f_seas=ds_f.sel(time=ds_f['time.month'].isin(seas)) 
        ds_t_seas=ds_t.sel(time=ds_t['time.month'].isin(seas)) 
        
        # Subtract the median to ensure data is centered and classes are equal
        median=np.percentile(ds_t_seas[varname],50)
        ds_t_seas[varname]=ds_t_seas[varname]-median
       
        # Create X Features input
        feature_vars=list(ds_f_seas.keys())
        X_pad=ds_f_seas.to_stacked_array('features',sample_dims=['time']).values
        
        # One Hot Encode Target (Y)
        Y_ohe=make_ohe_thresh_med(ds_t_seas[varname])
    
        print('Check Features and Target Dimensions')
        print('Features (X): ',X_pad.shape)
        print('Target (Y): ',Y_ohe.shape)
        
        # Loop over Models
        for imodel in range(0,nmodels):
            
            model_infname=model_path+'.'+seas_abbrv+'.'+model_type+'.'+str(imodel)+'.h5'
            print(model_infname)
            model=load_model(model_infname)
            print(model)
            #model.summary()
            
            # Accuracy Score
            score=model.evaluate(X_pad,Y_ohe) 
            print("%s: %.2f%%" % (model.metrics_names[1], score[1]*100))
            
            # Predictions
            Yprobs=model.predict(X_pad)
            Ypred=np.argmax(Yprobs, axis = 1)
            
            ds_pred=xr.DataArray(Ypred,coords={'time':ds_f_seas['time']},
                                 dims=['time']).to_dataset(name='pred')
    
            ds_probs=xr.DataArray(Yprobs,coords={'time':ds_f_seas['time'],
                                                 'cat':cat_labels},
                                        dims=['time','cat']).to_dataset(name='probs')
            ds_acc=xr.DataArray(score[1],
                                coords={'model':[imodel]},
                                dims=['model']).to_dataset(name='acc')    
            ds_verif=xr.DataArray(np.argmax(Y_ohe,axis=1),
                                  coords={'time':ds_f_seas['time']},
                                  dims=['time']).to_dataset(name='verif')
          
            #ds=xr.merge([ds_lrp,ds_pred,ds_verif,ds_probs,ds_acc])
            ds=xr.merge([ds_pred,ds_verif,ds_probs,ds_acc])
            
            #del ds_lrp,ds_pred,ds_verif,ds_probs,ds_acc
            del ds_pred,ds_verif,ds_probs,ds_acc
            
            model_ofname='../data/cnn/'+target_dataset+'/validate/model_validate.'+label+'.'+seas_abbrv+'.'+model_type+'.'+str(imodel)+'.nc'
            ds.to_netcdf(model_ofname)

Summer
Upper Cat:  1715
Lower Cat:  1715
Check Features and Target Dimensions
Features (X):  (3430, 15984)
Target (Y):  (3430, 2)
../data/cnn/era5-precip/models/seus.median.JJA.logmodel_med.0.h5
<keras.engine.sequential.Sequential object at 0x7f005c514410>
acc: 64.37%
../data/cnn/era5-precip/models/seus.median.JJA.logmodel_med.1.h5
<keras.engine.sequential.Sequential object at 0x7f00776fda50>
acc: 61.43%
../data/cnn/era5-precip/models/seus.median.JJA.logmodel_med.2.h5
<keras.engine.sequential.Sequential object at 0x7f00777e4e90>
acc: 62.39%
../data/cnn/era5-precip/models/seus.median.JJA.logmodel_med.3.h5
<keras.engine.sequential.Sequential object at 0x7f007790b850>
acc: 61.43%
../data/cnn/era5-precip/models/seus.median.JJA.logmodel_med.4.h5
<keras.engine.sequential.Sequential object at 0x7f005771aa50>
acc: 61.28%
../data/cnn/era5-precip/models/seus.median.JJA.logmodel_med.5.h5
<keras.engine.sequential.Sequential object at 0x7f00572dda90>
acc: 63.70%
../data/cnn/era5-precip/models/seus.