In [1]:
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from scipy import signal
import xesmf as xe
 
from mlprecip_utils import *
from mlprecip_models import *
from mlprecip_xai import *
from mlprecip_plot import *

import warnings
# Install a global "ignore" filter: every warning raised after this point is
# suppressed for the rest of the kernel session. Warnings emitted while the
# modules above were being imported are not affected.
warnings.filterwarnings("ignore")

Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# Name of the variable read from the target dataset
varname = 'precip'

# Model type identifiers to train; each entry is handed to trainCNN
model_types = ['cnn_cat']

# Calendar month numbers that make up each season (winter = DJF, summer = JJA)
winter, summer = [12, 1, 2], [6, 7, 8]

### Read in training and validation data

In [3]:
# Open the combined training/validation datasets: predictors (features) and
# the quantity to be predicted (target).
ds_features_tval = xr.open_dataset(
    '../data/cnn/features_trainval.nc'
)
ds_target_tval = xr.open_dataset(
    '../data/cnn/target_trainval.nc'
)

### Main Program to Train Models

In [None]:
# Train the requested model types separately for each season.
#
# For every season the target is restricted to that season's months, aligned
# in time with the features, centered on its median, and passed to trainCNN
# together with season-specific output file prefixes.

# Number of models to fit for each season / model type.
# Defined once here because it does not change between seasons.
nmodels=50

# Dropping time steps with missing feature values does not depend on the
# season, so do it once instead of repeating it on every loop iteration.
ds_features_valid=ds_features_tval.dropna(dim='time')

# Loop over seasons
for seas,slabel,seas_abbrv in zip([winter,summer],['Winter','Summer'],['DJF','JJA']):

    print(slabel)

    # Select season from target precip anomalies
    ds_i=ds_target_tval.sel(time=ds_target_tval['time.month'].isin(seas))

    # Make sure we have the same times for target and features:
    # the inner join keeps only time steps present in both datasets.
    ds_t,ds_f=xr.align(ds_i.dropna(dim='time'),
                       ds_features_valid,
                       join='inner')

    # Subtract the median (50th percentile over all target values of this
    # season) so the data is centered and the two classes are equal in size
    median=np.percentile(ds_t[varname],50)
    ds_t[varname]=ds_t[varname]-median

    #----- TRAIN MODELS  -----------#
    # Loop over Model Types
    for m_function in model_types:
        print(m_function)
        # File prefixes are unique per season and model type.
        # NOTE(review): presumably fname is for the saved models and ofname
        # for their output -- confirm against trainCNN in mlprecip_models.
        fname='../data/cnn_test/seus.median.'+seas_abbrv+'.'+m_function
        ofname='../data/cnn_test/model_output.'+seas_abbrv+'.'+m_function
        trainCNN(m_function,ds_f,ds_t,varname,nmodels,fname=fname,ofname=ofname)
        
        

Winter
cnn_cat
Upper Cat:  1715
Lower Cat:  1714
Check Features and Target Dimensions
Features (X):  (3429, 37, 92, 6)
Target (Y):  (3429, 2)
Training Size:  3086
Validation Size:  343
Epoch 00067: early stopping
Training set accuracy score: 0.7952041462382479
Validation set accuracy score: 0.746355694167468
Validation ROC AUC score: 0.8474221856574797
Epoch 00073: early stopping
Training set accuracy score: 0.7440051895434101
Validation set accuracy score: 0.6967930016121433
Validation ROC AUC score: 0.8236322501028384
Epoch 00064: early stopping
Training set accuracy score: 0.7556707753195673
Validation set accuracy score: 0.6938775529319274
Validation ROC AUC score: 0.8381667352255588
Epoch 00068: early stopping
Training set accuracy score: 0.7926117948564861
Validation set accuracy score: 0.7346938752919523
Validation ROC AUC score: 0.8489304812834224
Epoch 00064: early stopping
Training set accuracy score: 0.7854828266493314
Validation set accuracy score: 0.7172011675709539
Valida