In [1]:
import glob
import itertools
import os
import warnings

import numpy as np
import pandas as pd
import xarray as xr

In [2]:
def fixlons(nci, latdim, latstr, londim, lonstr):
    """Make longitude convention and lat/lon names consistent between models.

    If the smallest longitude is below -1 (a -180..180 style grid), 360 is
    added to every negative longitude. For a 1-D longitude the dataset is then
    sorted by longitude, so label slicing such as ``sel(lon=slice(130., 275.))``
    sees a monotonic axis. 2-D (curvilinear) longitudes are left in grid order.
    Finally the lat/lon variables are renamed to 'lat' and 'lon'.

    Parameters
    ----------
    nci : xarray.Dataset
        Dataset to fix. Its longitude variable is overwritten in place when
        wrapping is needed, as before.
    latdim, londim : str
        Grid dimension names. They are kept for backward compatibility only.
        The longitude variable's own dimensions are used when it is rewritten,
        because callers may have renamed dimensions after computing these.
    latstr, lonstr : str
        Names of the latitude and longitude variables in ``nci``.

    Returns
    -------
    xarray.Dataset
        Dataset with the latitude and longitude variables named 'lat' and 'lon'.
    """
    lon = nci[lonstr]
    if lon.ndim in (1, 2) and float(lon.min()) < -1.:
        lon_values = np.asarray(lon.values, dtype=float)
        wrapped = np.where(lon_values < 0., lon_values + 360., lon_values)
        # Keep the variable's own dims and its metadata (units etc.).
        nci[lonstr] = (lon.dims, wrapped, dict(lon.attrs))
        if lon.ndim == 1:
            # After wrapping, the axis reads [180..360, 0..180]. Sort it so that
            # range selections on 'lon' work.
            nci = nci.sortby(lonstr)
    return nci.rename({lonstr: 'lon', latstr: 'lat'})

In [3]:
# Bounds variables dropped when opening CMIP6 files; they are not used here.
_CMIP6_DROP_VARS = ['lat_bnds', 'lon_bnds', 'time_bnds', 'plev_bnds']


def _align_lat(ref, ds):
    """Give ``ds`` the latitude coordinate of ``ref`` when the values differ.

    This presumably absorbs round-off differences between files of the same
    model. Without it, concatenation would outer-join on 'lat'. Works for 1-D
    and 2-D latitude. Only same-shaped latitude arrays are overwritten.
    """
    ref_lat = ref['lat'].values
    new_lat = ds['lat'].values
    if ref_lat.shape == new_lat.shape and not np.array_equal(ref_lat, new_lat):
        ds['lat'] = ref['lat']
    return ds


def _read_member(files, variable, model):
    """Open the files of one ensemble member and concatenate them along 'time'.

    The first file decides the coordinate names (via ``coordnames``), and the
    same renaming is applied to every later file. For 'zg', only the levels
    nearest to 50000 and 20000 (plev units, presumably Pa) are kept.
    """
    member = None
    rename = None
    latdim = londim = None
    for path in files:
        ds = xr.open_dataset(path, engine='netcdf4', drop_variables=_CMIP6_DROP_VARS)
        if member is None:
            if model == 'EC-Earth3':
                ds = ds.load()
            latstr, lonstr, latdim, londim = coordnames(ds)
            if latstr != 'lat' or lonstr != 'lon':
                rename = {lonstr: 'lon', latstr: 'lat'}
        if rename is not None:
            ds = ds.rename(rename)
        ds = fixlons(ds, latdim, 'lat', londim, 'lon')
        if variable == 'zg':
            ds = ds.sel(plev=[50000., 20000.], method='nearest')
            ds = ds.assign_coords(plev=[50000., 20000.])
        if member is None:
            member = ds
        else:
            member = xr.concat([member, _align_lat(member, ds)], dim='time')
    return member


def read_all_ens(nci, variable, fili, variants, model, exp):
    """Read the listed ensemble members of one variable and stack them on 'ens'.

    Files are grouped into consecutive runs by their variant directory, the
    third path component from the end (``.../<variant>/<subdir>/<file>``).
    Runs are matched to ``variants`` in order: a run is read only if its
    variant equals the next unread entry of ``variants``, and other runs are
    skipped. Reading stops once every entry has been read.

    Parameters
    ----------
    nci : dict
        The result is stored in ``nci[variable]``.
    variable : str
        Name of the variable in the CMIP6 files.
    fili : list of str
        Sorted list of CMIP6 files to read from.
    variants : list of str
        Ensemble variant names to read, in file order.
    model : str
        CMIP6 model name.
    exp : str
        CMIP6 experiment: 'historical', 'ssp245', 'ssp370' or 'ssp585'.

    Returns
    -------
    xarray.Dataset or None
        ``nci[variable]``, cut to 1900-2014 (historical) or 2015-2100
        (scenarios). Returns None for an unsupported experiment name. Because
        ``nci[variable]`` is looked up at the end, a KeyError results when no
        member was read.
    """
    print('|=====| Beginning to read data for ' + variable)
    print('|=====| (nens = ' + str(len(variants)) + ')')
    if exp == 'historical':
        timslic = slice('1900-01-01', '2014-12-30')
    elif exp in ['ssp245', 'ssp370', 'ssp585']:
        timslic = slice('2015-01-01', '2100-12-30')
    else:
        print('Invalid experiment name')
        return
    ie = 0
    for vari, run in itertools.groupby(fili, key=lambda path: path.split('/')[-3]):
        if ie >= len(variants):
            # All requested members are read. Any remaining files belong to
            # unlisted variants, so stop instead of indexing past `variants`.
            break
        if vari != variants[ie]:
            continue
        print('|=====|=====| Beginning to read in ensemble member: ' + variants[ie])
        member = _read_member(list(run), variable, model).sel(time=timslic)
        if ie > 0:
            member = _align_lat(nci[variable], member)
        member = member.expand_dims(dim=dict(ens=[variants[ie]]))
        if ie == 0:
            nci[variable] = member
        else:
            nci[variable] = xr.concat([nci[variable], member], dim='ens')
        ie += 1
    print('|=====| Successfully read in data for ' + variable)
    return nci[variable]


In [4]:
def coordnames(nci):
    """Work out which latitude/longitude names a dataset uses.

    Parameters
    ----------
    nci : xarray.Dataset
        Any object exposing ``.variables`` and ``.dims`` collections of names.

    Returns
    -------
    tuple of str
        ``(latstr, lonstr, latdim, londim)``:

        - ``latstr`` is the first variable present among 'latitude', 'lat'
          and 'nav_lat'.
        - ``lonstr`` is the first present among 'longitude', 'lon' and
          'nav_lon'.
        - ``latdim``/``londim`` are ('lat', 'lon'), ('latitude', 'longitude')
          or ('nav_lat', 'nav_lon') when that latitude dimension exists.
          Otherwise they are ('y', 'x') if an 'x' dimension exists, else
          ('j', 'i').

    Raises
    ------
    ValueError
        If no recognised latitude or longitude variable is present.
    """
    varnames = list(nci.variables)
    dimnames = list(nci.dims)

    def _first_present(candidates, what):
        for name in candidates:
            if name in varnames:
                return name
        raise ValueError('No recognised ' + what + ' variable; looked for: '
                         + ', '.join(candidates))

    latstr = _first_present(('latitude', 'lat', 'nav_lat'), 'latitude')
    lonstr = _first_present(('longitude', 'lon', 'nav_lon'), 'longitude')

    for latdim, londim in (('lat', 'lon'), ('latitude', 'longitude'), ('nav_lat', 'nav_lon')):
        if latdim in dimnames:
            return latstr, lonstr, latdim, londim
    if 'x' in dimnames:
        return latstr, lonstr, 'y', 'x'
    return latstr, lonstr, 'j', 'i'

In [None]:
# Ensemble members to read for model/experiment pairs whose members are listed
# explicitly instead of being discovered from the directory tree.
_EXPLICIT_VARIANTS = {
    ('EC-Earth3', 'historical'): [
        'r10i1p1f1','r11i1p1f1','r12i1p1f1','r13i1p1f1','r14i1p1f1','r15i1p1f1',
        'r16i1p1f1','r17i1p1f1','r18i1p1f1','r19i1p1f1','r1i1p1f1',
        'r21i1p1f1','r22i1p1f1','r23i1p1f1','r24i1p1f1','r25i1p1f1','r2i1p1f1',
        'r3i1p1f1','r4i1p1f1','r6i1p1f1','r7i1p1f1','r9i1p1f1'],
    ('EC-Earth3', 'ssp245'): [
        'r101i1p1f1','r102i1p1f1','r103i1p1f1','r104i1p1f1','r105i1p1f1',
        'r106i1p1f1','r107i1p1f1','r108i1p1f1','r109i1p1f1','r10i1p1f1',
        'r10i1p1f2','r110i1p1f1','r111i1p1f1','r112i1p1f1','r113i1p1f1',
        'r114i1p1f1','r115i1p1f1','r116i1p1f1','r117i1p1f1','r118i1p1f1',
        'r119i1p1f1','r11i1p1f1','r120i1p1f1','r121i1p1f1','r122i1p1f1',
        'r123i1p1f1','r124i1p1f1','r125i1p1f1','r126i1p1f1','r127i1p1f1',
        'r128i1p1f1','r129i1p1f1','r130i1p1f1','r131i1p1f1',
        'r132i1p1f1','r133i1p1f1','r134i1p1f1','r135i1p1f1','r136i1p1f1',
        'r137i1p1f1','r138i1p1f1','r139i1p1f1','r13i1p1f1','r13i1p1f2',
        'r140i1p1f1','r141i1p1f1','r142i1p1f1','r143i1p1f1','r144i1p1f1',
        'r145i1p1f1','r146i1p1f1','r147i1p1f1','r148i1p1f1','r149i1p1f1',
        'r150i1p1f1','r15i1p1f1','r16i1p1f2','r18i1p1f2','r1i1p1f1','r20i1p1f2',
        'r22i1p1f2','r24i1p1f2','r26i1p1f2','r28i1p1f2',
        'r2i1p1f2','r4i1p1f1','r6i1p1f1','r6i1p1f2','r7i1p1f2'],
    ('EC-Earth3', 'ssp370'): [
        'r101i1p1f1','r102i1p1f1','r103i1p1f1','r104i1p1f1','r105i1p1f1','r106i1p1f1',
        'r107i1p1f1','r108i1p1f1','r109i1p1f1','r110i1p1f1','r111i1p1f1','r112i1p1f1','r113i1p1f1',
        'r114i1p1f1','r115i1p1f1','r116i1p1f1','r117i1p1f1','r118i1p1f1','r119i1p1f1','r11i1p1f1',
        'r120i1p1f1','r121i1p1f1','r122i1p1f1','r123i1p1f1','r124i1p1f1','r125i1p1f1','r126i1p1f1',
        'r127i1p1f1','r128i1p1f1','r129i1p1f1','r130i1p1f1','r131i1p1f1','r132i1p1f1','r133i1p1f1',
        'r134i1p1f1','r135i1p1f1','r136i1p1f1','r137i1p1f1','r138i1p1f1','r139i1p1f1','r13i1p1f1',
        'r140i1p1f1','r141i1p1f1','r142i1p1f1','r143i1p1f1','r144i1p1f1','r145i1p1f1','r146i1p1f1',
        'r147i1p1f1','r148i1p1f1','r149i1p1f1','r150i1p1f1','r15i1p1f1','r1i1p1f1','r4i1p1f1','r6i1p1f1','r9i1p1f1'],
    ('EC-Earth3-Veg', 'historical'): [
        'r10i1p1f1','r12i1p1f1','r14i1p1f1','r1i1p1f1','r2i1p1f1','r3i1p1f1','r4i1p1f1','r6i1p1f1'],
    ('EC-Earth3-Veg', 'ssp245'): [
        'r12i1p1f1','r14i1p1f1','r1i1p1f1','r2i1p1f1','r3i1p1f1','r4i1p1f1','r6i1p1f1'],
}

# Models whose ssp585 directories hold superfluous files. Only files with
# '2015' in their name are used, to pick out the right time period.
_SSP585_2015_ONLY = ['ACCESS-CM2', 'ACCESS-ESM1-5', 'IPSL-CM6A-LR', 'MIROC-ES2L', 'MRI-ESM2-0']


def _find_cmip6_files(diri, var, model, experiment, variants, pattern):
    """Return the file list for one variable of one model/experiment.

    ``pattern`` is the leaf file-name glob: '*.nc' for pr and '*' for tas.
    For EC-Earth3/EC-Earth3-Veg historical and ssp245, only the directories of
    ``variants`` are searched, concatenated in ``variants`` order.
    """
    base = diri + var + '/' + model + '/'
    if experiment == 'ssp585' and model in _SSP585_2015_ONLY:
        return sorted(glob.glob(base + '*/*/*2015*'))
    if experiment in ['historical', 'ssp245'] and model in ['EC-Earth3', 'EC-Earth3-Veg']:
        files = []
        for variant in variants:
            files.extend(sorted(glob.glob(base + variant + '/*/' + pattern)))
        return files
    return sorted(glob.glob(base + '*/*/' + pattern))


def _enso_indices(tas, latdim, londim, lonstr, model):
    """Compute the ELI and a Nino 3.4 index from 2-m air temperature.

    2-m air temperature is used as a proxy for SST because it is on the same
    grid as the atmospheric variables. This assumes ``tas`` has 'ens' and
    'time' dims plus 'lat'/'lon' coordinates usable for label slicing, in a
    0-360 longitude convention. TODO: confirm for curvilinear grids.

    Returns
    -------
    xarray.Dataset
        Variables 'eli' and 'n34' on ('ens', 'time').
    """
    with warnings.catch_warnings():
        # All-NaN reductions (e.g. no point warmer than the tropical mean) are expected.
        warnings.simplefilter("ignore", category=RuntimeWarning)

        # Tropical (5S-5N) mean over all longitudes, per member and month.
        troptas = tas.sel(lat=slice(-5., 5.), drop=True).mean(dim=[latdim, londim], skipna=True)
        # Equatorial Pacific band (130E-275E) relative to the tropical mean.
        # xarray broadcasts troptas over lat/lon by dimension name.
        sstanom = tas.sel(lat=slice(-5., 5.), lon=slice(130., 275.), drop=True) - troptas

        print('|=====| Beginning computation of ELI')
        # ELI = mean longitude of the grid points warmer than the tropical mean.
        # Works for both 1-D and 2-D longitude coordinates.
        eli = (sstanom[lonstr].where(sstanom > 0.)
               .mean(dim=[latdim, londim], skipna=True)
               .transpose('ens', 'time'))

        # Nino 3.4 box (5S-5N, 190-240E): anomalies from the monthly climatology,
        # box-averaged, then smoothed with a centred 5-month running mean.
        # min_periods=3 shrinks the window to 3 and 4 points at both series ends.
        n34_box = tas.sel(lat=slice(-5, 5), lon=slice(190, 240), drop=True)
        n34_anom = n34_box.groupby('time.month') - n34_box.groupby('time.month').mean(dim='time')
        n34_anom = n34_anom.mean(dim=[latdim, londim], skipna=True)
        n34 = n34_anom.rolling(time=5, center=True, min_periods=3).mean().transpose('ens', 'time')

        enso_ind = xr.Dataset(
            data_vars=dict(
                eli=(['ens', 'time'], eli.data),
                n34=(['ens', 'time'], n34.data),
            ),
            coords=dict(
                time=tas['time'].data,
                # Members actually read, which may be fewer than were requested.
                ens=tas['ens'].data,
            ),
            attrs=dict(description='ELI and Nino 3.4 data from: ' + model),
        )
        print('|=====| Completed computation of ELI')
    return enso_ind


def prep_cmip6(model, experiment, rootdir='/glade/scratch/nlybarger/data/climate_data/cmip6/'):
    """Write pr, tas, ELI and Nino 3.4 for one CMIP6 model/experiment to one netCDF file.

    The output is ``<rootdir>postproc/<model>.<experiment>.nc``. It holds pr
    (multiplied by 86400, presumably kg m-2 s-1 -> mm/day), tas, eli and n34,
    all with an 'ens' dimension and consistent 'lat'/'lon' names.

    Parameters
    ----------
    model : str
        CMIP6 model name, i.e. the directory under ``<experiment>/Amon/<var>/``.
    experiment : str
        'historical', 'ssp245', 'ssp370' or 'ssp585'.
    rootdir : str
        Root of the CMIP6 directory tree, ending in '/'.

    Returns None in every case. It returns early for models that are skipped,
    when the output file already exists, or when no pr/tas files are found.
    """
    if model == 'ICON-ESM-LR':
        print('|=====|=====|=====|=====|=====|=====|=====|=====|=====|=====|=====|')
        print(model + ' has a weird grid.  Sort it out later')
        return
    elif model == 'MIROC-ES2H':
        print('|=====|=====|=====|=====|=====|=====|=====|=====|=====|=====|=====|')
        print(model + ' data only includes 1850.  Skipping.')
        return
    print('|=====|=====|=====|=====|=====|=====|=====|=====|=====|=====|=====|')
    print('Beginning preprocessing of CMIP6 data for model: ' + model + ' ' + experiment)

    diri = rootdir + experiment + '/Amon/'
    filo = rootdir + 'postproc/' + model + '.' + experiment + '.nc'
    if os.path.exists(filo):
        print('Computation already completed for model: ' + model + ' ' + experiment)
        return

    explicit = _EXPLICIT_VARIANTS.get((model, experiment))

    filipr = _find_cmip6_files(diri, 'pr', model, experiment, explicit or [], '*.nc')
    if not filipr:
        print('pr does not exist for ' + model + ' ' + experiment)
        return

    if explicit is not None:
        variants = list(explicit)
    else:
        # Distinct variant directories, in first-appearance order.
        variants = list(dict.fromkeys(path.split('/')[-3] for path in filipr))

    filitas = _find_cmip6_files(diri, 'tas', model, experiment, variants, '*')
    if not filitas:
        print('tas does not exist for ' + model + ' ' + experiment)
        return

    nci = {}
    nci['pr'] = read_all_ens(nci, 'pr', filipr, variants, model, experiment)
    if nci['pr'] is None:
        # read_all_ens rejected the experiment name.
        return
    nci['tas'] = read_all_ens(nci, 'tas', filitas, variants, model, experiment)

    _, lonstr, latdim, londim = coordnames(nci['pr'])
    enso_ind = _enso_indices(nci['tas']['tas'], latdim, londim, lonstr, model)

    nci['pr']['pr'] = nci['pr']['pr'] * 86400

    nco = nci['pr']['pr'].to_dataset()
    nco = nco.assign(tas=nci['tas']['tas'])
    nco = nco.assign(eli=enso_ind['eli'])
    nco = nco.assign(n34=enso_ind['n34'])
    nco.to_netcdf(filo, mode='w', format='NETCDF4')

    print('Completed preprocessing of CMIP6 data for model: ' + model)

In [None]:
# Driver: for each experiment, preprocess every model directory found under
# <experiment>/Amon/tas/.
for exp in ('historical', 'ssp245', 'ssp370', 'ssp585'):
    diri = '/glade/scratch/nlybarger/data/climate_data/cmip6/' + exp + '/Amon/'
    filis = sorted(glob.glob(diri + 'tas/*'))
    nfil = len(filis)
    # The model name is the last path component of each tas directory.
    models = [os.path.basename(path) for path in filis]
    for mod in models:
        prep_cmip6(mod, exp)