In [2]:
import numpy as np
import xarray as xr
import pandas as pd
import datetime
from netCDF4 import Dataset
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap
from eofs.standard import Eof
from eofs.multivariate.standard import MultivariateEof
from scipy import signal

In [3]:
# Function to read ERA5 data 
def _era5_lon_groups(lon, subset, passpm):
    """Return index lists selecting the requested longitude band from ``lon``.

    Normal case: one list with ``subset[0] < lon < subset[1]``.
    Prime-Meridian case: two lists ``[west, east]`` where ``west`` holds points
    with ``lon > subset[0]`` and ``east`` points with ``lon < subset[1]``.
    Bounds are exclusive.
    """
    if passpm:
        return [lon[lon > subset[0]].index.tolist(),
                lon[lon < subset[1]].index.tolist()]
    return [lon[(lon > subset[0]) & (lon < subset[1])].index.tolist()]


def _era5_read_subset(var, lat_idx, lon_groups):
    """Read ``var[:, lat_idx, group]`` for each longitude group and join along longitude.

    ``np.ma.concatenate`` is used because ``np.concatenate`` silently drops the
    masks netCDF4 attaches to missing values.
    """
    pieces = [var[:, lat_idx, lon_idx] for lon_idx in lon_groups]
    if len(pieces) == 1:
        return pieces[0]
    return np.ma.concatenate(pieces, axis=2)


def _era5_weighted_soil_moisture(layers, min_valid=1e-4):
    """Layer-thickness-weighted mean of the three ERA5 soil-moisture layers.

    Values that are NaN, already masked, or below ``min_valid`` (treated as
    missing) are ignored, and ``np.ma.average`` renormalises the weights over
    the remaining layers. Points with no valid layer are returned as NaN.
    """
    # Layer thicknesses in metres (0-7, 7-28, 28-100 cm), from
    # https://confluence.ecmwf.int/pages/viewpage.action?pageId=56660259
    thick = np.array([0.07, 0.21, 0.72])
    # Stack on a trailing axis; np.ma.stack keeps masks (np.concatenate would not)
    stacked = np.ma.stack(layers, axis=-1)
    stacked = np.ma.masked_invalid(stacked)
    stacked = np.ma.masked_less(stacked, min_valid)
    return np.ma.average(stacked, axis=-1, weights=thick / thick.sum()).filled(np.nan)


def _era5_mask_fill_value(var, values, tol=0.01):
    """Set entries of ``values`` within ``tol`` of ``var``'s unpacked fill value to NaN.

    The fill value is unpacked as ``_FillValue * scale_factor + add_offset``.
    Scale 1 and offset 0 are used when those attributes are absent. If ``var``
    has no ``_FillValue`` attribute, ``values`` is returned unchanged.
    ``values`` is modified in place and also returned.
    """
    fill_value = getattr(var, '_FillValue', None)
    if fill_value is None:
        return values
    unpacked_fv = (fill_value * getattr(var, 'scale_factor', 1.0)
                   + getattr(var, 'add_offset', 0.0))
    values[abs(values - unpacked_fv) <= tol] = np.nan
    return values


def _era5_plot_first_step(lon_values, lat_values, field):
    """Draw one (lat, lon) field on a global cylindrical Basemap as a quick visual check."""
    plt.figure(figsize=(10, 6), edgecolor='w')
    m = Basemap(projection='cyl',
                llcrnrlat=-90, urcrnrlat=90,
                llcrnrlon=-180, urcrnrlon=180)
    lon2d, lat2d = np.meshgrid(lon_values, lat_values)
    m.pcolormesh(lon2d, lat2d, field, latlon=True, cmap='plasma')
    m.drawcoastlines()
    plt.colorbar()
    plt.show()


def read_era5(variable, subset=(240.0, 15.0, -80.0, 20.0), plot=True, SM=False, passpm=False,
              era5_dir='/glade/work/cab478/era5'):
    """Read a lon/lat box of a monthly ERA5 variable into an ``xarray.DataArray``.

    Parameters
    ----------
    variable : {'precip', 'sm', 't2m'}
        Abbreviation used in the file name ``<era5_dir>/era5_<variable>_1980_2016.nc``.
    subset : sequence of 4 floats
        ``[lon_min, lon_max, lat_min, lat_max]``, with longitudes in the file's
        convention (0-360 assumed by the conversion below). Bounds are exclusive.
        ``lon_min > lon_max`` means the box crosses the Prime Meridian.
    plot : bool
        If True, draw the first time step on a global Basemap.
    SM : bool
        Return the thickness-weighted mean of soil layers ``swvl1``-``swvl3``.
        Implied when ``variable == 'sm'``.
    passpm : bool
        Force the Prime-Meridian-crossing longitude selection. Also enabled
        automatically when ``lon_min > lon_max``.
    era5_dir : str
        Directory holding the ERA5 files.

    Returns
    -------
    xarray.DataArray
        Dims ``('time', 'lat', 'lon')``; time is month-end dates for Sep-Mar,
        1980-2016. In the normal case longitudes >= 180 are shifted by -360.
        In the Prime-Meridian case every point west of the meridian is shifted
        by -360. Detected missing values are NaN.

    Raises
    ------
    ValueError
        Raised for any of:
        - an unknown ``variable``;
        - ``SM=True`` with a non-soil-moisture variable;
        - an empty lon/lat selection;
        - a file time axis that does not match the expected dates.
    """
    # ERA5 variable name(s) inside the file, keyed by the file-name abbreviation
    era5_names = {'precip': 'tp', 'sm': ['swvl1', 'swvl2', 'swvl3'], 't2m': 't2m'}
    if variable not in era5_names:
        raise ValueError('Unknown ERA5 variable {!r}; expected one of {}'.format(
            variable, sorted(era5_names)))
    varname = era5_names[variable]

    # Soil moisture is stored as several layers and must go through the SM path
    SM = SM or variable == 'sm'
    if SM and variable != 'sm':
        raise ValueError("SM=True requires variable='sm', got {!r}".format(variable))

    # A box whose western bound exceeds its eastern bound must wrap past 0 deg lon
    passpm = passpm or subset[0] > subset[1]

    era5_file = era5_dir + '/era5_' + variable + '_1980_2016.nc'
    # Let user know which file is being read
    print(era5_file)

    # Month-end dates, keeping only Sep-Mar (austral spring/summer).
    # NOTE(review): assumes the file stores exactly these months -- checked below.
    dates = pd.date_range(start='1/1/1980', end='12/31/2016', freq=pd.offsets.MonthEnd())
    dates = dates[(dates.month >= 9) | (dates.month <= 3)]

    with Dataset(era5_file) as data:
        lat = pd.Series(data.variables['latitude'][:])
        lon = pd.Series(data.variables['longitude'][:])

        lon_groups = _era5_lon_groups(lon, subset, passpm)
        lat_idx = lat[(lat > subset[2]) & (lat < subset[3])].index.tolist()
        if not lat_idx or not any(lon_groups):
            raise ValueError('subset {} selects no ERA5 grid points (passpm={})'.format(
                list(subset), passpm))

        if SM:
            layers = [_era5_read_subset(data.variables[name], lat_idx, lon_groups)
                      for name in varname]
            out = _era5_weighted_soil_moisture(layers)
        else:
            var = data.variables[varname]
            out = _era5_mask_fill_value(var, _era5_read_subset(var, lat_idx, lon_groups))

    lat_values = np.asarray(lat[lat_idx].values)
    if passpm:
        west_idx, east_idx = lon_groups
        lon_values = np.concatenate((np.asarray(lon[west_idx]) - 360.,
                                     np.asarray(lon[east_idx])))
    else:
        lon_values = np.asarray(lon[lon_groups[0]])
        lon_values = np.where(lon_values >= 180, lon_values - 360., lon_values)

    if out.shape[0] != len(dates):
        raise ValueError('{} has {} time steps but {} Sep-Mar dates were expected'.format(
            era5_file, out.shape[0], len(dates)))

    # xarray converts masked entries of a numpy masked array to NaN
    array = xr.DataArray(out, coords={'time': dates, 'lat': lat_values, 'lon': lon_values},
                         dims=['time', 'lat', 'lon'])

    if plot:
        _era5_plot_first_step(lon_values, lat_values, out[0, :, :])

    return array