In [1]:
import numpy as np
import xarray as xr 
import pandas as pd
import datetime
from netCDF4 import Dataset
import Nio
import Ngl
import matplotlib.pyplot as plt
from mpl_toolkits.basemap import Basemap

In [1]:
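# Define function to read in CMIP6 data 
# "variable" refers to the desired CMIP6 variable
# "comp" refers to the CESM component directory name used in the file path (the surface pressure file, for example, is read from 'atm')
# "model" refers to the CESM model name used in the file name (the surface pressure file, for example, uses 'cam')
# "subset" refers to the desired spatial domain to be extracted, given as [lon_min, lon_max, lat_min, lat_max];
#          longitudes follow the 0 to 360 convention of the files and all four bounds are exclusive
# "plev" is the vertical level to which data should be interpolated - MUST BE IN MB - right now only one level is supported
# "vinterp" should be set to True if vertical interpolation from model to pressure levels is necessary
# "SM" should be set to True if this function is being used to read in soil moisture data
# "passpm" should be set to True if the requested spatial subset includes the Prime Meridian
# "run" is the run identifier inserted into the directory and file names (default '001')
# "plot" should be set to True to plot the first time step of the extracted data as a quick check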
# Author: Carolina Bieri (bieri2@illinois.edu)

def _fill_to_nan(values, nc_var):
    """Set entries of ``values`` equal to ``nc_var._FillValue`` to NaN, in place.

    Variables that carry no ``_FillValue`` attribute are returned untouched
    instead of raising ``AttributeError``.
    """
    fill = getattr(nc_var, '_FillValue', None)
    if fill is not None:
        values[values == fill] = np.nan
    return values


def _read_subset(nc_var, lead, lat_idx, lon_parts):
    """Read ``nc_var[lead..., lat_idx, lon_idx]`` for every index list in ``lon_parts``.

    ``lead`` is a tuple of slices for the dimensions preceding lat/lon.  When
    more than one longitude index list is given (subset crossing the Prime
    Meridian) the pieces are joined along the last (longitude) axis so that the
    result is one continuous spatial subset.
    """
    pieces = [nc_var[lead + (lat_idx, lon_idx)] for lon_idx in lon_parts]
    if len(pieces) == 1:
        return pieces[0]
    return np.concatenate(pieces, axis=-1)


def read_cesm(variable, comp, model, subset=[240.0,360.0,-80.0,20.0], plev=500.,vinterp=False, SM=False, passpm=False,
                  run='001', plot=True):
    """Read a spatial subset of a monthly CMIP6-historical CESM variable.

    Parameters
    ----------
    variable : str
        Name of the variable to read; it is used both in the file name and as
        the key into the file's variables.
    comp : str
        Component directory name used to build the file path.
    model : str
        Model name used to build the file name.
    subset : sequence of four floats
        ``[lon_min, lon_max, lat_min, lat_max]``.  Longitudes use the 0 to 360
        convention of the files; all bounds are exclusive.  The default list is
        only read, never modified.
    plev : float
        Pressure level in mb to interpolate to (only used when ``vinterp``).
    vinterp : bool
        If True, the variable is read as (time, lev, lat, lon) and interpolated
        to ``plev`` with ``Ngl.vinth2p`` using surface pressure from the
        corresponding ``cam.h0.PS`` file.  Takes precedence over ``SM``.
    SM : bool
        If True, the first 8 levels of the variable are read and combined into
        a thickness-weighted average.
    passpm : bool
        If True, the subset crosses the Prime Meridian: points with
        ``lon > subset[0]`` are followed by points with ``lon < subset[1]``.
    run : str
        Run identifier inserted into the directory and file names.
    plot : bool
        If True, the first time step is drawn on a Basemap as a quick check.

    Returns
    -------
    xarray.DataArray
        Data with dims ``('time', 'lat', 'lon')``.  The time coordinate is the
        fixed range of month ends from 1850 through 2014, so the file must hold
        exactly that many time steps.  Longitudes are expressed in the -180 to
        180 convention.
    """
    # Set top level directory for CMIP6 data
    CMIP_dir     = '/gpfs/fs1/collections/cdg/timeseries-cmip6/b.e21.BHIST.f09_g17.CMIP6-historical.'+ run + '/'
    # Define file names for desired variable as well as surface pressure
    filename_var = CMIP_dir + comp + '/proc/tseries/month_1/b.e21.BHIST.f09_g17.CMIP6-historical.' + run + '.' + model + \
                    '.h0.' + variable + '.185001-201412.nc'
    filename_PS  = CMIP_dir + 'atm/proc/tseries/month_1/b.e21.BHIST.f09_g17.CMIP6-historical.' + run + \
                    '.cam.h0.PS.185001-201412.nc'

    print(filename_var)

    # Set reference pressure in mb
    P0 = 1000.

    # Date range corresponding to data: 1/1/1850 to 12/31/2014 (month ends).
    # The offset object is used instead of the 'M' alias, which newer pandas
    # releases deprecate.
    date_rng = pd.date_range(start='1/1/1850', end='12/31/2014', freq=pd.offsets.MonthEnd())

    # Slices for the dimensions that precede lat/lon in each read mode
    every = slice(None)
    if vinterp:
        lead = (every, every)            # time, all model levels
    elif SM:
        lead = (every, slice(None, 8))   # time, first 8 layers (root zone)
    else:
        lead = (every,)                  # time only

    # Everything needed from the variable file is read inside the context
    # manager so the file handle is released even if a read fails.
    with Dataset(filename_var) as ds_var:
        # Convert lat and lon arrays to Series
        lat = pd.Series(ds_var.variables['lat'][:])
        lon = pd.Series(ds_var.variables['lon'][:])

        # Get lat/lon indices of spatial subset
        if passpm:
            # Points up to the Prime Meridian, then points beyond it
            lon_parts = [lon[lon>subset[0]].index.tolist(),
                         lon[lon<subset[1]].index.tolist()]
        else:
            lon_parts = [lon[(lon>subset[0])&(lon<subset[1])].index.tolist()]

        lat_slice = lat[(lat>subset[2])&(lat<subset[3])].index.tolist()

        # Extract variable using determined indices and set missing values to NaNs
        nc_var = ds_var.variables[variable]
        var = _fill_to_nan(_read_subset(nc_var, lead, lat_slice, lon_parts), nc_var)

        if vinterp:
            # Read in coefficients necessary to perform vertical interpolation
            hyam = ds_var.variables['hyam'][:]
            hybm = ds_var.variables['hybm'][:]

    # Perform vertical interpolation if vinterp = True
    if vinterp:
        # Read in surface pressure data over the same spatial subset
        with Dataset(filename_PS) as ds_PS:
            PS = _read_subset(ds_PS.variables['PS'], (every,), lat_slice, lon_parts)

        # Perform vertical interpolation to desired pressure level
        var_interp = Ngl.vinth2p((var),hyam,hybm,[plev],PS,1,P0,1,True)
        # Get rid of degenerate dimension
        out = np.squeeze(var_interp[:,0,:,:])

    # Process for extracting soil moisture data
    elif SM:
        # Create masked array before computing weighted average
        masked_var = np.ma.masked_array(var, np.isnan(var))

        # Define array of CLM soil layer thicknesses up to layer 8
        thick   = np.array([0.02,0.04,0.06,0.08,0.12,0.16,0.2,0.24])
        # Compute layer thickness weights
        wgts    = thick/thick.sum()
        # Perform weighted average using weights
        wgt_avg = np.ma.average(masked_var,axis=1,weights=wgts)
        # Set weighted average SM data as output array
        out     = wgt_avg.filled(np.nan)

    # Do this in all other cases
    else:
        out = var

    # In both cases, the convention of the longitude array is changed from 0 to 360 to -180 to 180
    if passpm:
        lon_values = np.asarray(pd.concat((lon[lon_parts[0]]-360.0,lon[lon_parts[1]])))
    else:
        lon_values = np.asarray(lon[lon_parts[0]])
        # np.where builds a new array, so pandas-owned (possibly read-only)
        # memory is never written to
        lon_values = np.where(lon_values>=180, lon_values - 360., lon_values)

    lat_values = np.array(lat[lat_slice].values)

    # Define xarray Data Array to hold data and metadata
    array = xr.DataArray(out, coords={'time':date_rng,'lat':lat_values,
                                      'lon':lon_values}, dims=['time','lat','lon'])

    if plot:
        # Plot first time step to check data
        plt.figure(figsize=(10, 6), edgecolor='w')
        m = Basemap(projection='cyl',
                llcrnrlat=-90, urcrnrlat=90,
                llcrnrlon=-180, urcrnrlon=180)
        # Separate names keep the lat/lon coordinate Series intact
        lon_grid, lat_grid = np.meshgrid(lon_values, lat_values)
        m.pcolormesh(lon_grid, lat_grid, out[0,:,:],
             latlon=True, cmap='plasma')
        m.drawcoastlines()
        plt.colorbar()
        plt.show()


    return array