In [1]:
import pandas as pd
import numpy as np
import xarray as xr
import datetime
from netCDF4 import Dataset
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap

In [6]:
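# Define function to read MERRA-2 data
# "variable" argument is the name of the variable inside the netCDF file (the key used in ds.variables)
# "identifier" argument is the string used to build the file name: MERRA_daily_avg_<identifier>_1980_2016.nc
#   (it may or may not be the same as "variable")
# "resolution" refers to temporal resolution: 'D' returns the daily data unchanged; any other value
#   (e.g. 'M' for monthly) is used as the resampling frequency and the mean over each period is returned
# "subset" is the desired spatial domain for which data will be extracted: [lon1,lon2,lat1,lat2]
#   (grid points strictly inside these bounds are kept)
# "fourd" should be set to True for 4-dimensional data (time,lev,lat,lon); only the first level (index 0) is read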
# Author: Carolina Bieri (bieri2@illinois.edu)

def _plot_first_time_step(field):
    """Draw the first time step of *field* on a global cylindrical map."""
    plt.figure(figsize=(10, 6), edgecolor='w')
    m = Basemap(projection='cyl',
                llcrnrlat=-90, urcrnrlat=90,
                llcrnrlon=-180, urcrnrlon=180)
    lon_grid, lat_grid = np.meshgrid(field.coords['lon'].values,
                                     field.coords['lat'].values)
    mesh = m.pcolormesh(lon_grid, lat_grid, field[0, :, :],
                        latlon=True, cmap='plasma')
    m.drawcoastlines()
    plt.colorbar(mesh)
    plt.show()


def get_MERRA2_time_avg(variable, identifier, resolution,
                        subset=(-120.0, 30.0, -80.0, 20.0), fourd=False,
                        plot=True):
    """Read a MERRA-2 daily-average file and return the data on a sub-domain.

    Parameters
    ----------
    variable : str
        Name of the variable inside the netCDF file (key of ``ds.variables``).
    identifier : str
        String used to build the file name
        ``MERRA_daily_avg_<identifier>_1980_2016.nc``.
    resolution : str
        ``'D'`` returns the daily data unchanged.  Any other value is passed
        to ``DataArray.resample`` as the frequency (e.g. ``'M'``) and the
        mean over each period is returned.
    subset : sequence of 4 numbers
        Spatial domain ``[lon1, lon2, lat1, lat2]``.  Grid points strictly
        inside these bounds are kept.
    fourd : bool
        Set to True when the variable has four dimensions; the first index
        along the second dimension is extracted.
    plot : bool
        When True (the default) the first time step is plotted as a quick
        visual check.

    Returns
    -------
    xarray.DataArray
        Array with dimensions ``('time', 'lat', 'lon')``; missing values are
        NaN.

    Raises
    ------
    ValueError
        If ``subset`` does not hold four values, its bounds are not
        increasing, or it selects no grid points.
    """
    # Validate the domain before touching the file system so that bad
    # arguments fail fast with a clear message.
    if len(subset) != 4:
        raise ValueError('subset must contain four values: '
                         '[lon1, lon2, lat1, lat2]')
    lon_min, lon_max, lat_min, lat_max = subset
    if not (lon_min < lon_max and lat_min < lat_max):
        raise ValueError('subset bounds must satisfy lon1 < lon2 and '
                         'lat1 < lat2')

    # Define location of MERRA-2 files
    MERRA2_dir = '/glade/work/cab478/MERRA/day/'
    # File containing daily averaged data from 1980 to 2016
    filename = MERRA2_dir + 'MERRA_daily_avg_' + identifier + '_1980_2016.nc'
    print(filename)

    # Daily time axis assigned to the data (1980 through 2016)
    date_rng = pd.date_range(start='1/1/1980', end='12/31/2016', freq='D')

    # The context manager guarantees the file handle is released even if
    # reading fails part-way through.
    with Dataset(filename) as ds:
        lat_all = np.asarray(ds.variables['lat'][:])
        lon_all = np.asarray(ds.variables['lon'][:])

        # Indices of the grid points strictly inside the requested domain
        lat_idx = np.flatnonzero((lat_all > lat_min)
                                 & (lat_all < lat_max)).tolist()
        lon_idx = np.flatnonzero((lon_all > lon_min)
                                 & (lon_all < lon_max)).tolist()
        if not lat_idx or not lon_idx:
            raise ValueError('subset selects no grid points')

        nc_var = ds.variables[variable]
        if fourd:
            data = nc_var[:, 0, lat_idx, lon_idx]
            # Reshape explicitly instead of np.squeeze: squeeze would also
            # drop a lat or lon axis of length one and break the dims below.
            data = data.reshape((-1, len(lat_idx), len(lon_idx)))
        else:
            data = nc_var[:, lat_idx, lon_idx]

        # Not every variable declares a fill value
        fill_value = getattr(nc_var, '_FillValue', None)

    # Set missing data to NaNs.  NaN needs a floating dtype, and both
    # already-masked entries and raw fill values have to be covered.
    if not np.issubdtype(data.dtype, np.floating):
        data = data.astype(float)
    if fill_value is not None:
        data = np.ma.masked_equal(data, fill_value, copy=False)
    data = np.ma.filled(data, np.nan)

    # Define xarray DataArray to hold data and metadata
    array = xr.DataArray(data,
                         coords={'time': date_rng,
                                 'lat': lat_all[lat_idx],
                                 'lon': lon_all[lon_idx]},
                         dims=['time', 'lat', 'lon'])

    # Resample the time dimension and average unless daily data is requested
    if resolution == 'D':
        final = array
    else:
        final = array.resample(time=resolution).mean(dim='time')

    # Plot first time step to check data
    if plot:
        _plot_first_time_step(final)

    return final