In [None]:
import numpy as np
import glob
import warnings
import datetime
import scipy.stats
import matplotlib.pyplot as plt
import os
import xarray as xr
import xskillscore as xs
import pandas as pd
import xesmf as xesmf
import operator

In [None]:
def pnw_average(dataArray,latstr,lonstr):
    """Cosine-latitude weighted spatial mean over the fixed PNW evaluation box.

    Grid points strictly inside 40.5 < lat < 50.5 and 235.5 < lon < 253.5 are
    kept (the longitude bounds assume a 0-360 convention); everything else is
    dropped before averaging over the `latstr` and `lonstr` dimensions with
    NaNs skipped. Any other dimensions (e.g. time, ens) are preserved.
    """
    lat = dataArray[latstr]
    lon = dataArray[lonstr]
    in_box = (lat > 40.5) & (lat < 50.5) & (lon > 235.5) & (lon < 253.5)
    boxed = dataArray.where(in_box, drop=True)

    # Area weights proportional to cos(latitude)
    coslat = np.cos(np.deg2rad(boxed[latstr]))
    coslat.name = 'weights'
    return boxed.weighted(coslat).mean(dim=[latstr,lonstr],skipna=True)

def regional_average(dataArray,lonmin,lonmax,lonstr,latmin,latmax,latstr):
    """Cosine-latitude weighted spatial mean over a caller-supplied lat/lon box.

    Grid points strictly inside latmin < lat < latmax and lonmin < lon < lonmax
    are kept; all others are dropped before averaging over the `latstr` and
    `lonstr` dimensions with NaNs skipped. The longitude bounds must use the
    same convention (e.g. 0-360) as `dataArray`.

    Bug fix: the bound arguments used to be ignored and the fixed PNW box
    (40.5-50.5 lat, 235.5-253.5 lon) was always averaged instead.
    """
    lat = dataArray[latstr]
    lon = dataArray[lonstr]
    in_box = (lat > latmin) & (lat < latmax) & (lon > lonmin) & (lon < lonmax)
    datout = dataArray.where(in_box, drop=True)

    # Area weights proportional to cos(latitude)
    weights = np.cos(np.deg2rad(datout[latstr]))
    weights.name = 'weights'
    return datout.weighted(weights).mean(dim=[latstr,lonstr],skipna=True)

In [None]:
def coordnames(nci):
    """Work out the latitude/longitude variable names and horizontal dimension names.

    Variable names are searched in the order latitude/lat/nav_lat and
    longitude/lon/nav_lon. Dimension names are chosen by probing the
    dimensions for 'lat', 'latitude', 'nav_lat' or 'x' (in that order); if
    none is present the generic ('j', 'i') pair is assumed.

    Parameters
    ----------
    nci : xarray.Dataset (or anything exposing `.variables` and `.dims`)

    Returns
    -------
    (latstr, lonstr, latdim, londim) : tuple of str

    Raises
    ------
    ValueError
        If no recognised latitude or longitude variable exists (this used to
        surface as an opaque UnboundLocalError).
    """
    varnames = set(nci.variables)
    dimnames = set(nci.dims)

    latstr = next((name for name in ('latitude', 'lat', 'nav_lat') if name in varnames), None)
    lonstr = next((name for name in ('longitude', 'lon', 'nav_lon') if name in varnames), None)
    if latstr is None or lonstr is None:
        raise ValueError('Could not identify latitude/longitude variables among: '
                         + ', '.join(sorted(map(str, varnames))))

    # (probe dimension, latdim, londim): the first probe present among the dims wins
    for probe, latdim, londim in (('lat', 'lat', 'lon'),
                                  ('latitude', 'latitude', 'longitude'),
                                  ('nav_lat', 'nav_lat', 'nav_lon'),
                                  ('x', 'y', 'x')):
        if probe in dimnames:
            return latstr, lonstr, latdim, londim
    return latstr, lonstr, 'j', 'i'

In [None]:
# Helper: move longitudes from the [-180, 180] convention onto [0, 360]
def fixlons(nci,latdim,londim,lonstr):
    """Shift negative longitudes by +360 so they follow the [0, 360) convention.

    The shift is applied only when the smallest (NaN-ignoring) longitude is
    below -1, i.e. when the grid clearly uses [-180, 180]; grids already on
    [0, 360] are returned untouched.

    Parameters
    ----------
    nci : xarray.Dataset (or any mapping supporting item assignment)
        Holds the longitude variable; it is modified in place.
    latdim, londim : str
        Dimension names used when re-assigning the longitude array: both for a
        2-D (curvilinear) longitude field, only `londim` for a 1-D one.
    lonstr : str
        Name of the longitude variable.

    Returns
    -------
    The same `nci` object. When a shift is applied, the longitude variable is
    replaced by a float64 array. Longitude arrays that are neither 1-D nor 2-D
    are left unchanged.
    """
    lon = np.asarray(nci[lonstr], dtype=float)
    if lon.ndim not in (1, 2):
        return nci
    if float(np.nanmin(lon)) < -1.:
        # Vectorised replacement for the former element-by-element Python loop
        shifted = np.where(lon < 0., lon + 360., lon)
        dims = [latdim, londim] if lon.ndim == 2 else [londim]
        nci[lonstr] = (dims, shifted)
    return nci

In [None]:
#computes the linear trend of precipitation and temperature (only used for SSP calculations)
def compute_trend(nci, latstr='lat', lonstr='lon'):
    """Linear trends of PNW-averaged annual precipitation and temperature.

    Parameters
    ----------
    nci : xarray.Dataset
        Monthly dataset with 'pr' and 'tas' and a datetime-like 'time' axis.
    latstr, lonstr : str, optional
        Latitude / longitude coordinate names (default 'lat' / 'lon', as before).

    Returns
    -------
    ptrend, ttrend : float
        NaN-ignoring mean (over any remaining dimensions, e.g. ensemble
        members) of the slope per year multiplied by 100, i.e. per century.
    ptmp : xarray.DataArray
        The un-averaged precipitation trend(s).
    """
    month_length = nci.time.dt.days_in_month
    # pr * days-in-month / 10; presumably a daily rate in mm/day -> cm per month (confirm units)
    pr_monthly = (nci['pr']*month_length)/10

    p_annual = pnw_average(pr_monthly,latstr,lonstr).groupby('time.year').sum(dim='time',skipna=False)
    t_annual = pnw_average(nci['tas'],latstr,lonstr).groupby('time.year').mean(dim='time')

    ptmp = xs.linslope(p_annual['year'],p_annual,dim='year',skipna=False)*100
    ttmp = xs.linslope(t_annual['year'],t_annual,dim='year',skipna=False)*100
    ptrend = np.nanmean(ptmp)
    ttrend = np.nanmean(ttmp)

    return ptrend,ttrend,ptmp

In [None]:
def seasonal_avg_vars(nci,model,latstr,lonstr,obs):
    """Seasonal (DJF/MAM/JJA/SON) means of eli, n34, pr, tas and the pr/tas anomalies.

    Parameters
    ----------
    nci : xarray.Dataset
        Dataset with 'eli', 'n34', 'pr' and 'tas' on a 'time' axis. Model input
        (obs=False) must also have an 'ens' dimension; the pr/tas fields are
        written as (ens, lat, lon) per time mean, which assumes that dimension
        order — TODO confirm for new inputs.
    model : str
        Label written into each output dataset's description attribute.
    latstr, lonstr : str
        Names of the latitude / longitude coordinates in `nci`.
    obs : bool
        False for ensemble (model) input, True for single-realisation observations.

    Returns
    -------
    dict[str, xarray.Dataset]
        One dataset per season holding eli, n34, pr, pranom, tas, tasanom.
        Season-year iy covers Dec of years[iy] (DJF only) or the season of
        years[iy+1]; its time stamp is 1 January of years[iy+1]. An empty
        dict is returned if `obs` is neither True nor False.

    Notes
    -----
    Side effect: adds 'pranom' and 'tasanom' (departures from the monthly
    climatology) to `nci`.
    Season bounds are month-resolution partial-date strings (e.g. '1901-12' to
    '1902-02'), which select whole months in every calendar. The previous
    '-02-28', '-05-30' and '-08-30' end dates dropped Feb 29, May 31 and Aug 31.
    """
    if obs == False:
        nens = len(nci['ens'])
    years = list(nci.groupby('time.year').groups)
    nyr = len(years)
    seaskeys = ['DJF','MAM','JJA','SON']

    drs = {seas: {} for seas in seaskeys}
    for iy in range(nyr-1):
        prev_yr, this_yr = str(years[iy]), str(years[iy+1])
        drs['DJF'][iy] = slice(prev_yr+'-12', this_yr+'-02')
        drs['MAM'][iy] = slice(this_yr+'-03', this_yr+'-05')
        drs['JJA'][iy] = slice(this_yr+'-06', this_yr+'-08')
        drs['SON'][iy] = slice(this_yr+'-09', this_yr+'-11')

    nci['pranom'] = nci['pr'].groupby('time.month') - nci['pr'].groupby('time.month').mean(dim='time')
    nci['tasanom'] = nci['tas'].groupby('time.month') - nci['tas'].groupby('time.month').mean(dim='time')

    if obs == False:
        lead_dims, lead_shape = ['ens'], (nens,)
    elif obs == True:
        lead_dims, lead_shape = [], ()
    else:
        return {}

    nlat = len(nci[latstr])
    nlon = len(nci[lonstr])
    index_vars = ['eli','n34']
    field_vars = ['pr','pranom','tas','tasanom']

    outvars = {}
    for seas in seaskeys:
        data_vars = {}
        for var in index_vars + field_vars:
            spatial = var in field_vars
            arr = np.zeros(lead_shape + (nyr-1,) + ((nlat, nlon) if spatial else ()))
            for iy in range(nyr-1):
                # [:, iy] for ensemble input, [iy] for observations
                slot = (slice(None),)*len(lead_shape) + (iy,)
                arr[slot] = nci[var].sel(time=drs[seas][iy]).mean(dim='time').values
            dims = lead_dims + ['time'] + (['lat','lon'] if spatial else [])
            data_vars[var] = (dims, arr)

        coords = {}
        if lead_dims:
            coords['ens'] = (['ens'], nci['ens'].data)
        # NOTE(review): pandas >= 2.2 deprecates the 'AS' alias in favour of 'YS'
        coords['time'] = pd.date_range(str(years[0]+1)+'-01-01', periods=nyr-1, freq='AS')
        coords['lat'] = (['lat'],nci[latstr].data)
        coords['lon'] = (['lon'],nci[lonstr].data)

        outvars[seas] = xr.Dataset(
            data_vars = data_vars,
            coords = coords,
            attrs=dict(description= seas + ' average variables from: ' + model),
        )
    return outvars

In [None]:
# Reading in CMIP6 datasets and parsing model/variant information.
# Each postprocessed historical file is opened lazily; the model name is the
# file name up to its first '.', and the ensemble members come from its 'ens'
# coordinate.
diri='/glade/u/home/nlybarger/scratch/data/climate_data/cmip6/postproc/'
filis = sorted(glob.glob(diri + '*historical.nc'))
nfil = len(filis)
models = [path.split('/')[-1].split('.')[0] for path in filis]

nci = {}
variants = {}
for model_name, fil in zip(models, filis):
    ds_model = xr.open_dataset(fil,engine='netcdf4')
    nci[model_name] = ds_model
    variants[model_name] = list(ds_model['ens'].data)
# Read by the metrics loop to decide whether pr still needs scaling
firstrun = True

In [None]:
# Read in observational datasets (1-degree western-CONUS files) and the
# monthly Nino3.4 / ELI index series, then build per-dataset seasonal means.

odiri = '/glade/work/nlybarger/data/OBS/'
odsets = ['CRU','ERA-5','GMET','UDel','Livneh','PRISM']
fnames = ['cru','era5','gmetensm','udel','livneh','prism',]
obs = {}
oyears = {}
for dset, fname in zip(odsets, fnames):
    obs[dset] = xr.open_dataset(odiri + dset + '/1deg.' + fname + '.wconus.p.t.nc',engine='netcdf4')
    oyears[dset] = list(obs[dset].groupby('time.year').groups)

# Both index series are reduced to Jan 1900 - Dec 2019: 120 years x 12 months.
# `with` closes each file even if parsing raises.
n34f = '/glade/work/nlybarger/data/clim_indices/nino34.1870-2021.txt'
with open(n34f,'r') as fp:
    n34o = np.genfromtxt(fp,delimiter=',',usecols=np.arange(1,13),dtype='f4')
# Assumes one row per year starting in 1870 (per the file name): rows 30:150 -> 1900-2019
n34o = np.reshape(n34o[30:150,:],(120*12))

elifi = '/glade/work/nlybarger/data/clim_indices/ELI_ERSSTv5_1854.01-2019.12.csv'
with open(elifi,'r') as fp:
    elio = np.genfromtxt(fp,delimiter=',',usecols=np.arange(47,167),dtype='f4',skip_header=1)
# 120 selected columns; presumably one column per year and one row per month (confirm file layout)
elio = np.transpose(elio)
elio = np.reshape(elio,(120*12,))

indy = xr.Dataset(
        data_vars = dict(
            eli=(['time'], elio),
            n34=(['time'], n34o),
        ),
        coords = dict(
            time=(['time'], pd.date_range('1900-01-01','2019-12-31',freq='MS')),
        ),
)

oseasvars = {}
for dset in odsets:
    print(dset)
    # NOTE(review): the last two years of these datasets are dropped; the reason is not recorded here
    if dset in ['CRU','ERA-5','GMET','PRISM']:
        oyears[dset] = oyears[dset][:-2]

    period = slice(str(oyears[dset][0])+'-01-01',str(oyears[dset][-1])+'-12-31')
    obs[dset] = obs[dset].sel(time=period)
    # Index values are attached positionally; this requires obs and index months to line up one-to-one
    obs[dset]['n34'] = (['time'],indy['n34'].sel(time=period).data)
    obs[dset]['eli'] = (['time'],indy['eli'].sel(time=period).data)
    oseasvars[dset] = seasonal_avg_vars(obs[dset],'obs','lat','lon',True)

In [None]:
# Compute metrics for each observational dataset

# Window used for the DJF ENSO correlation maps here and for the seasonal
# climatologies in later cells (slice-based .sel assumes ascending lat/lon)
pnwlat = slice(35,55)
pnwlon = slice(230,258)

# DJF teleconnection metric name -> (ENSO index variable, anomaly field)
obs_enso_pairs = {
    'n34pr': ('n34', 'pranom'),
    'elipr': ('eli', 'pranom'),
    'n34t':  ('n34', 'tasanom'),
    'elit':  ('eli', 'tasanom'),
}


def _mean_annual_range(series, years):
    """Mean over `years` of (max - min) of `series` within each calendar year."""
    ranges = []
    for year in years:
        one_year = series.sel(time=slice(str(year)+'-01-01',str(year)+'-12-31'))
        ranges.append(float(one_year.max()-one_year.min()))
    return np.array(ranges).mean()


dmeanto = {}
dsampto = {}
dmeanpo = {}
dsamppo = {}
ddjf_corrs_obs = {}
dptrendo = {}
dttrendo = {}

for dset in odsets:
    tas_obs = obs[dset]['tas']
    pr_obs_cm = obs[dset]['pr']/10

    # Mean-T: PNW average of annual-mean temperature, averaged over years
    gbto = pnw_average(tas_obs,'lat','lon')
    dmeanto[dset] = pnw_average(tas_obs.groupby('time.year').mean(dim='time',skipna=True),'lat','lon').mean()

    # Seasonal Amplitude-T
    dsampto[dset] = _mean_annual_range(gbto, oyears[dset])

    # Mean-P: PNW average of annual total precipitation, averaged over years
    gbpo = pnw_average(pr_obs_cm,'lat','lon')
    dmeanpo[dset] = pnw_average(pr_obs_cm.groupby('time.year').sum(dim='time',skipna=False),'lat','lon').mean()

    # Seasonal Amplitude-P
    dsamppo[dset] = _mean_annual_range(gbpo, oyears[dset])

    # Nino3.4/ELI - variable anomaly correlation maps, DJF
    djf_obs = oseasvars[dset]['DJF']
    ddjf_corrs_obs[dset] = {
        met: xs.pearson_r(djf_obs[index_name],djf_obs[field_name].sel(lat=pnwlat,lon=pnwlon,drop=True),dim='time')
        for met, (index_name, field_name) in obs_enso_pairs.items()
    }

    # 1901-2014 trends (slope per year * 100), only for the long CRU and UDel records
    if dset in ['CRU','UDel']:
        trend_window = slice('1901-01-01','2014-12-30')
        annual_p = pnw_average((obs[dset]['pr'].sel(time=trend_window,drop=True)/10).groupby('time.year').sum(dim='time',skipna=False),'lat','lon')
        annual_t = pnw_average(obs[dset]['tas'].sel(time=trend_window,drop=True).groupby('time.year').mean(dim='time',skipna=True),'lat','lon')
        dptrendo[dset] = (xs.linslope(annual_p['year'], annual_p, dim='year')*100).data
        dttrendo[dset] = (xs.linslope(annual_t['year'], annual_t, dim='year')*100).data

# Reference grid that all models and obs are regridded to for comparison.
# np.arange excludes its stop value: lon 236..252, lat 41..49.
# NOTE(review): pnw_average's box (40.5<lat<50.5, 235.5<lon<253.5) would also
# include lat 50 and lon 253 on an integer-centred grid; confirm which extent is intended.
dummy1deg = xr.Dataset(
    data_vars = dict(),
    coords = dict(
        lon = (['lon'], np.arange(236,253)),
        lat = (['lat'], np.arange(41,50)),
    ),
)

In [None]:
# Compute observational mean metrics
#
# Averages the scalar metrics and DJF ENSO correlation maps across the obs
# datasets, builds seasonal climatologies, and scores how well each obs dataset
# agrees with the others. The agreement scores serve as an obs-uncertainty baseline.

ensomets = ['n34pr','elipr','n34t','elit']
nd = len(odsets)


def _offdiag_row_mean(n, pair_value):
    """NaN-skipping row means of the n x n matrix M[i, j] = pair_value(i, j), diagonal excluded."""
    mat = np.full((n, n), np.nan)
    for row in range(n):
        for col in range(n):
            if row != col:
                mat[row, col] = pair_value(row, col)
    return np.nanmean(mat, axis=1)


# Sums over datasets. np.asarray handles both 0-d DataArrays and numpy scalars
# (`.data` on a numpy scalar is a memoryview, not the value).
meanto = np.zeros(1)
sampto = np.zeros(1)
meanpo = np.zeros(1)
samppo = np.zeros(1)
ptrendo = np.zeros(1)
ttrendo = np.zeros(1)
dco = {met: np.zeros(1) for met in ensomets}
ndt = 0
for dset in odsets:
    meanto = meanto + np.asarray(dmeanto[dset])
    sampto = sampto + np.asarray(dsampto[dset])
    meanpo = meanpo + np.asarray(dmeanpo[dset])
    samppo = samppo + np.asarray(dsamppo[dset])
    for met in ensomets:
        dco[met] = dco[met] + ddjf_corrs_obs[dset][met].data
    # Trends exist only for CRU and UDel, so they get their own divisor
    if dset in ['CRU','UDel']:
        ndt += 1
        ptrendo = ptrendo + dptrendo[dset]
        ttrendo = ttrendo + dttrendo[dset]

meanto = meanto/nd
sampto = sampto/nd
meanpo = meanpo/nd
samppo = samppo/nd
ptrendo = ptrendo/ndt
ttrendo = ttrendo/ndt
for met in ensomets:
    dco[met] = dco[met]/nd

# Calendar months averaged for each season's climatology
avgr = {'DJF': [1,2,12], 'MAM': [3,4,5], 'JJA': [6,7,8], 'SON': [9,10,11]}

seasvars_obs = {}
dseasvars_obs = {}
dseas_scorrs = {}
dseas_stdevs = {}
for seas in ['DJF','MAM','JJA','SON']:
    seasvars_obs[seas] = {}
    dseasvars_obs[seas] = {}
    for dset in odsets:
        dseasvars_obs[seas][dset] = {}
        for var in ['tas','pr']:
            dseasvars_obs[seas][dset][var] = (obs[dset][var].sel(lat=pnwlat,lon=pnwlon,drop=True)
                                              .groupby('time.month').mean(dim='time')
                                              .sel(month=avgr[seas],drop=True).mean(dim='month'))

    # Mean seasonal climatology across obs datasets
    for var in ['tas','pr']:
        total = dseasvars_obs[seas][odsets[0]][var]
        for dset in odsets[1:]:
            total = total + dseasvars_obs[seas][dset][var]
        seasvars_obs[seas][var] = total/nd

    # For each dataset: mean (over the other datasets) spatial correlation and
    # std-dev ratio of its climatology. Column 0 = tas, column 1 = pr.
    corr_cols = []
    sd_cols = []
    for var in ['tas','pr']:
        fields = [dseasvars_obs[seas][dset][var] for dset in odsets]
        corr_cols.append(_offdiag_row_mean(
            nd, lambda a, b, f=fields: xs.pearson_r(f[a],f[b],dim=['lat','lon'],skipna=True).data))
        sd_cols.append(_offdiag_row_mean(
            nd, lambda a, b, f=fields: np.nanstd(f[a].data)/np.nanstd(f[b].data)))
    dseas_scorrs[seas] = np.column_stack(corr_cols)
    dseas_stdevs[seas] = np.column_stack(sd_cols)

djf_corrs_obs = xr.Dataset(
    data_vars = dict(
        n34pr = (['lat','lon'], dco['n34pr']),
        elipr = (['lat','lon'], dco['elipr']),
        n34t = (['lat','lon'], dco['n34t']),
        elit = (['lat','lon'], dco['elit']),
    ),
    coords = dict(
        lon = (['lon'], seasvars_obs['DJF']['tas']['lon'].data),
        lat = (['lat'], seasvars_obs['DJF']['tas']['lat'].data),
    ),
)

# Per dataset: mean pattern correlation of its DJF ENSO maps with the other
# datasets' maps (matrix size follows len(odsets) instead of a hard-coded 6)
obs_enso_corrs = {
    met: _offdiag_row_mean(
        nd, lambda a, b, m=met: xs.pearson_r(ddjf_corrs_obs[odsets[a]][m],ddjf_corrs_obs[odsets[b]][m],skipna=True).data)
    for met in ensomets
}

In [None]:
# Observational metric table: one row per obs dataset. Columns:
#   0-3   Mean-T and Mean-P (as departures from the obs mean), Seasonal Amplitude-T/-P
#   4-5   T-trend, P-trend (CRU and UDel only; NaN otherwise)
#   6-9   mean inter-dataset pattern correlation of DJF ENSO maps: n34pr, elipr, n34t, elit
#   10-17 seasonal spatial correlation with the other obs, T (DJF..SON) then P (DJF..SON)
#   18-25 seasonal spatial std-dev ratio with the other obs, T then P
# NOTE(review): the means are stored as departures but the amplitudes as raw values; confirm intended.
seasons = ['DJF','MAM','JJA','SON']
obs_rows = []
for i, dset in enumerate(odsets):
    has_trend = dset in ['CRU','UDel']
    row = [
        float(dmeanto[dset]-meanto.item()),
        float(dmeanpo[dset]-meanpo.item()),
        float(dsampto[dset]),
        float(dsamppo[dset]),
        float(dttrendo[dset]) if has_trend else np.nan,
        float(dptrendo[dset]) if has_trend else np.nan,
    ]
    row += [obs_enso_corrs[met][i] for met in ['n34pr','elipr','n34t','elit']]
    for table in (dseas_scorrs, dseas_stdevs):
        for ivar in (0, 1):  # 0 = T, 1 = P
            row += [table[seas][i, ivar] for seas in seasons]
    obs_rows.append(row)

obsmet = np.array(obs_rows, dtype=float)
# Number of populated metric columns (previously hard-coded as 28, leaving two all-NaN columns)
nmet = obsmet.shape[1]
print(nmet)

In [None]:
from matplotlib.patches import Polygon

# Map of the PNW evaluation box.
# NOTE(review): `latty`, `lonny` and `Basemap` are not defined or imported in
# this notebook. The cell only runs if they exist in earlier kernel state;
# presumably `from mpl_toolkits.basemap import Basemap` plus 1-D lat/lon arrays
# giving the map extent — confirm and define them explicitly.
fig = plt.figure()
ll_lat = latty[0]
ur_lat = latty[-1]
ll_lon = lonny[0]
ur_lon = lonny[-1]
m = Basemap(resolution='i',projection='merc',llcrnrlat=ll_lat,llcrnrlon=ll_lon,
            urcrnrlat=ur_lat,urcrnrlon=ur_lon)

# Box corners (0-360 longitudes).
# NOTE(review): these differ slightly from pnw_average's 40.5-50.5 N / 235.5-253.5 E; confirm.
lats = [ 40.5, 49.5, 49.5, 40.5 ]
lons = [ 234.5, 234.5, 253.5, 253.5 ]
# Projected corner coordinates. Must NOT be named `xr`: that would overwrite the
# xarray module used by later cells (xr.where, xr.Dataset, xr.open_dataset).
xpoly, ypoly = m( lons, lats )
poly = Polygon( list(zip(xpoly,ypoly)),facecolor='none',edgecolor='red',linewidth = 3, alpha=1.0 )
plt.gca().add_patch(poly)
m.drawcoastlines()
m.drawstates()
m.drawcountries()
m.drawparallels(np.arange(int(ll_lat),int(ur_lat),5.),labels=[1,0,0,0])
m.drawmeridians(np.arange(int(ll_lon-1),int(ur_lon-1),10.),labels=[0,0,0,1])
fig.patch.set_facecolor('w')
plt.title('PNW Evaluation Domain')
plt.savefig('/glade/work/nlybarger/data/hydromet/ESM_eval_semifinal_plots/final/domain.png',dpi=300,bbox_inches='tight',facecolor='w')

### Metric Checklist
- MeanT
- MeanP
- SeasonAmpT
- SeasonAmpP
- DJF_ELI_med_bias
- DJF_ELI_Levene
- SpaceCor_N34_P
- SpaceCor_ELI_P
- SpaceCor_N34_T
- SpaceCor_ELI_T
- SpaceCor - MMMT
- SpaceCor - MMMP (expanded domain)
- SpaceSD - MMMT
- SpaceSD - MMMP (expanded domain)

In [None]:
# output directory
diro = '/glade/work/nlybarger/data/hydromet/cmip6_metrics/PNW/'


def _mean_annual_range_by_member(series, members):
    """Per ensemble member: average over calendar years of the within-year (max - min) range.

    `series` must have 'time' and 'ens' dimensions (here a PNW area average of a
    model field). Grouping by 'time.year' uses every time step of each year. The
    former '-01-01'..'-12-30' slices dropped any 31 December time stamp.
    Returns a float ndarray ordered like `members`.
    """
    by_year = series.groupby('time.year')
    annual_range = (by_year.max(dim='time') - by_year.min(dim='time')).transpose('ens', 'year')
    return np.asarray(annual_range.sel(ens=members), dtype=float).mean(axis=1)


# DJF ENSO teleconnection metric name -> (ENSO index variable, anomaly field)
model_enso_pairs = {
    'n34pr': ('n34', 'pranom'),
    'elipr': ('eli', 'pranom'),
    'n34t':  ('n34', 'tasanom'),
    'elit':  ('eli', 'tasanom'),
}

im = 0
for mod in models:
    outfile = diro + mod + '.cmip6.metrics.PNW.nc'
    if os.path.exists(outfile):
        print('Metrics already computed for ' + mod + '.  Advancing.')
        continue
    if mod == 'NorCPM1':
        print(mod + ' raises an error from the ESMF regridder.  Skipping.')
        continue
    members = variants[mod]
    nvar = len(members)
    im+=1
    print(str(im) + ': Beginning computation of verification metrics for model: ' + mod)
    latstr,lonstr,latdim,londim = coordnames(nci[mod])

    # converts tas to degrees C if in K
    if nci[mod]['tas'].max() > 100.:
        nci[mod]['tas'] = nci[mod]['tas']-273.15
    # Scale pr by days-in-month exactly once per dataset. A marker attribute replaces
    # the global `firstrun` flag, which re-scaled pr if a previous run crashed midway.
    if not nci[mod]['pr'].attrs.get('scaled_by_month_length', False):
        pr_monthly = nci[mod]['pr']*nci[mod].time.dt.days_in_month
        pr_monthly.attrs['scaled_by_month_length'] = True
        nci[mod]['pr'] = pr_monthly

    # Mean-T: PNW average of annual-mean temperature, averaged over years (per member)
    gbt = pnw_average(nci[mod]['tas'],latstr,lonstr)
    meant = pnw_average(nci[mod]['tas'].groupby('time.year').mean(dim='time',skipna=True),latstr,lonstr)
    meant = meant.mean(dim='year',skipna=True)

    # Seasonal Amplitude-T
    sampt = _mean_annual_range_by_member(gbt, members)

    # Mean-P: pr is a per-month total here; /10 converts mm to cm (units a la Rupp et al)
    pr_cm = nci[mod]['pr']/10
    gbp = pnw_average(pr_cm,latstr,lonstr)
    meanp = pnw_average(pr_cm.groupby('time.year').sum(dim='time',skipna=False),latstr,lonstr)
    meanp = meanp.mean(dim='year',skipna=True)

    # Seasonal Amplitude-P
    sampp = _mean_annual_range_by_member(gbp, members)

    # DJF ELI median and Levene's statistic (median-centred) against CRU-period observed ELI
    seasvars = seasonal_avg_vars(nci[mod],mod,latstr,lonstr,False)
    djf = seasvars['DJF']
    obs_eli = oseasvars['CRU']['DJF']['eli'].data
    elimed = np.zeros(nvar)
    levstat = np.zeros(nvar)
    for iv, member in enumerate(members):
        member_eli = djf['eli'].sel(ens=member)
        elimed[iv] = np.median(member_eli)
        levstat[iv],_ = scipy.stats.levene(member_eli.data,obs_eli,center='median')

    # Correlations between ELI/N34 and pr/tas anomalies
    djf_corrs_mod = {
        met: xs.pearson_r(djf[index_name],djf[field_name].sel(lat=pnwlat,lon=pnwlon,drop=True),dim='time')
        for met, (index_name, field_name) in model_enso_pairs.items()
    }
    djf_mets = list(djf_corrs_mod.keys())

    regridder_pnw = xesmf.Regridder(djf.sel(ens=members[0],lat=pnwlat,lon=pnwlon,drop=True),
                                    dummy1deg,'bilinear')
    for met in djf_mets:
        djf_corrs_mod[met] = regridder_pnw(djf_corrs_mod[met]).transpose('ens', 'lat', 'lon')

    # Pattern correlation of each member's maps with EACH obs dataset's maps, then
    # averaged over obs datasets. (Previously indexed by a stale global `dset`, so
    # every column compared against the same single obs dataset.)
    djf_enso_scorrs = np.zeros((len(djf_mets),nvar,nd))
    for imet, met in enumerate(djf_mets):
        for ie, member in enumerate(members):
            member_corr = djf_corrs_mod[met].sel(ens=member,drop=True)
            for idset, obs_name in enumerate(odsets):
                djf_enso_scorrs[imet,ie,idset] = xs.pearson_r(member_corr,ddjf_corrs_obs[obs_name][met],skipna=True).data
    djf_enso_scorrs = np.nanmean(djf_enso_scorrs,axis=2)

    # Mean seasonal climatology: spatial correlation and std-dev ratio vs each obs
    # dataset, averaged over obs datasets. Index 0 = tas, 1 = pr.
    seaslist = ['DJF','MAM','JJA','SON']
    seas_scorrs = np.zeros((2,4,nvar,nd))
    seas_sstdev = np.zeros((2,4,nvar,nd))
    seastas = {}
    seaspr = {}
    for isea, seas in enumerate(seaslist):
        for ivar, (vname, store) in enumerate((('tas', seastas), ('pr', seaspr))):
            clim = (nci[mod][vname].sel(lat=pnwlat,lon=pnwlon,drop=True)
                    .groupby('time.month').mean(dim='time')
                    .sel(month=avgr[seas],drop=True).mean(dim='month'))
            clim = regridder_pnw(clim)
            # Mask model points where the multi-obs mean is missing
            obs_mean = seasvars_obs[seas][vname]
            for ie in range(nvar):
                clim[ie,:,:] = xr.where(~np.isnan(obs_mean),clim[ie,:,:],np.nan)
            store[seas] = clim

            for ie, member in enumerate(members):
                member_clim = clim.sel(ens=member,drop=True)
                for idset, obs_name in enumerate(odsets):
                    obs_clim = dseasvars_obs[seas][obs_name][vname]
                    seas_scorrs[ivar,isea,ie,idset] = xs.pearson_r(member_clim,obs_clim,dim=['lat','lon'],skipna=True).data
                    seas_sstdev[ivar,isea,ie,idset] = np.nanstd(member_clim.data)/np.nanstd(obs_clim.data)
    seas_scorrs = np.nanmean(seas_scorrs,axis=3)
    seas_sstdev = np.nanmean(seas_sstdev,axis=3)

    # 1901-2014 trends (slope per year * 100) of PNW annual P total and annual-mean T.
    # Month-resolution bounds include all of December in any calendar.
    trend_window = slice('1901-01','2014-12')
    pt = pnw_average(nci[mod]['pr'].sel(time=trend_window,drop=True)/10,latstr,lonstr)
    pt = pt.groupby('time.year').sum(dim='time',skipna=False)
    tt = pnw_average(nci[mod]['tas'].sel(time=trend_window,drop=True),latstr,lonstr)
    tt = tt.groupby('time.year').mean(dim='time')
    ptrend = xs.linslope(pt['year'],pt,dim='year',skipna=False)*100
    ttrend = xs.linslope(tt['year'],tt,dim='year',skipna=False)*100

    metrics = xr.Dataset(
                data_vars = dict(
                    meant      = (['ens'], meant.data),
                    sampt      = (['ens'], sampt),
                    meanp      = (['ens'], meanp.data),
                    sampp      = (['ens'], sampp),
                    ttrend     = (['ens'], ttrend.data),
                    ptrend     = (['ens'], ptrend.data),

                    elimed = (['ens'], elimed),
                    eli_djf = (['ens','time'], djf['eli'].data),
                    levstat    = (['ens'], levstat),

                    n34pr_rdjf = (['ens'], djf_enso_scorrs[0,:]),
                    elipr_rdjf = (['ens'], djf_enso_scorrs[1,:]),
                    n34t_rdjf  = (['ens'], djf_enso_scorrs[2,:]),
                    elit_rdjf  = (['ens'], djf_enso_scorrs[3,:]),

                    djf_t_r    = (['ens'], seas_scorrs[0,0,:]),
                    djf_pr_r   = (['ens'], seas_scorrs[1,0,:]),
                    mam_t_r    = (['ens'], seas_scorrs[0,1,:]),
                    mam_pr_r   = (['ens'], seas_scorrs[1,1,:]),
                    jja_t_r    = (['ens'], seas_scorrs[0,2,:]),
                    jja_pr_r   = (['ens'], seas_scorrs[1,2,:]),
                    son_t_r    = (['ens'], seas_scorrs[0,3,:]),
                    son_pr_r   = (['ens'], seas_scorrs[1,3,:]),

                    djf_t_sd   = (['ens'], seas_sstdev[0,0,:]),
                    djf_pr_sd  = (['ens'], seas_sstdev[1,0,:]),
                    mam_t_sd   = (['ens'], seas_sstdev[0,1,:]),
                    mam_pr_sd  = (['ens'], seas_sstdev[1,1,:]),
                    jja_t_sd   = (['ens'], seas_sstdev[0,2,:]),
                    jja_pr_sd  = (['ens'], seas_sstdev[1,2,:]),
                    son_t_sd   = (['ens'], seas_sstdev[0,3,:]),
                    son_pr_sd  = (['ens'], seas_sstdev[1,3,:]),

                    djf_t  = (['ens','lat','lon'], seastas['DJF'].data),
                    djf_pr = (['ens','lat','lon'], seaspr['DJF'].data),
                    mam_t  = (['ens','lat','lon'], seastas['MAM'].data),
                    mam_pr = (['ens','lat','lon'], seaspr['MAM'].data),
                    jja_t  = (['ens','lat','lon'], seastas['JJA'].data),
                    jja_pr = (['ens','lat','lon'], seaspr['JJA'].data),
                    son_t  = (['ens','lat','lon'], seastas['SON'].data),
                    son_pr = (['ens','lat','lon'], seaspr['SON'].data),

                    djf_n34_pr_scorr = (['ens','lat','lon'], djf_corrs_mod['n34pr'].data),
                    djf_eli_pr_scorr = (['ens','lat','lon'], djf_corrs_mod['elipr'].data),
                    djf_n34_t_scorr  = (['ens','lat','lon'], djf_corrs_mod['n34t'].data),
                    djf_eli_t_scorr  = (['ens','lat','lon'], djf_corrs_mod['elit'].data),

                ),
                coords = dict(
                    ens  = (['ens'], nci[mod]['ens'].data),
                    time = (['time'], djf['time'].data),
                    lat  = (['lat'], dummy1deg['lat'].data),
                    lon  = (['lon'], dummy1deg['lon'].data),
                ),
                attrs = dict(
                    description=('CMIP6 metrics for model: ' + mod))
            )
    metrics.to_netcdf(outfile,mode='w')

    del seasvars
    del metrics
    del seaspr
    del seastas
# Kept for compatibility; pr scaling is now tracked per dataset via an attribute
firstrun = False

In [None]:
# Load the per-model metric files written by the metrics loop and drop
# models/members excluded from the evaluation.
if 'NorCPM1' in models:
    models.remove('NorCPM1')

diro = '/glade/work/nlybarger/data/hydromet/cmip6_metrics/PNW/'
metrics = {}
for mod in models:
    tmpfil = diro + mod + '.cmip6.metrics.PNW.nc'
    if os.path.exists(tmpfil):
        metrics[mod] = xr.open_dataset(tmpfil,engine='netcdf4')

# Other MCM-UA-1-0 ensemble members show odd behavior; keep only member index 1.
# Guarded so a missing metrics file no longer raises KeyError.
if 'MCM-UA-1-0' in metrics:
    metrics['MCM-UA-1-0'] = metrics['MCM-UA-1-0'].isel(ens=[1])

# These models are removed from the list for various reasons
# Bizarre precipitation calculations, likely incorrect units in database
remmodels = ['KIOST-ESM','CIESM','E3SM-1-1','NorCPM1']
for mod in remmodels:
    metrics.pop(mod, None)
models = sorted(metrics.keys())

In [None]:
# Assemble model metrics into arrays indexed [model, metric, ensemble member].
#
# Each entry of metric_refs is (variable name in the metrics files, reference value).
# The absolute error for a metric is |model value - reference|. The ORDER of this
# list defines the metric axis of modmet/errs, and it must match metric_titles and
# the panel layout of the plotting cell:
#   0-5   : mean, seasonal amplitude and trend of T and P, compared with obs values
#   6-9   : DJF spatial correlation with obs of Nino3.4 / ELI patterns (perfect = 1)
#   10-17 : seasonal spatial correlation with obs, T then P (perfect = 1)
#   18-25 : seasonal spatial standard deviation, T then P (reference = 1)
metric_refs = (
    [('meant', meanto.item()),
     ('meanp', meanpo.item()),
     ('sampt', sampto),
     ('sampp', samppo),
     ('ttrend', ttrendo),
     ('ptrend', ptrendo)]
    + [(name, 1) for name in ('n34pr_rdjf', 'elipr_rdjf', 'n34t_rdjf', 'elit_rdjf')]
    + [(f'{s}_{v}_r', 1) for v in ('t', 'pr') for s in ('djf', 'mam', 'jja', 'son')]
    + [(f'{s}_{v}_sd', 1) for v in ('t', 'pr') for s in ('djf', 'mam', 'jja', 'son')]
)

# Number of metrics (26), derived from the table so it can never drift out of sync
nmet = len(metric_refs)

# Max number of ensemble members: at least 72 (keeps array shapes stable for
# downstream cells), grown if any model has more members so assignment never fails
nens = max([72] + [len(metrics[mod]['ens']) for mod in models])

# Number of models to be evaluated
nmod = len(models)

# Raw metric values (NaN-padded beyond each model's ensemble size)
modmet = np.full((nmod,nmet,nens),np.nan)

# Absolute error matrix as compared to obs mean (or perfect correlation/stdev)
errs = np.full((nmod,nmet,nens),np.nan)

for i, mod in enumerate(models):
    nensmod = len(metrics[mod]['ens'])
    for j, (varname, ref) in enumerate(metric_refs):
        modmet[i,j,:nensmod] = metrics[mod][varname]
        errs[i,j,:nensmod] = abs(modmet[i,j,:nensmod] - ref)

In [None]:
# Subplot titles, one per metric, in the same order as the metric axis of modmet/errs.
# The first ten carry units or the statistic name. The last sixteen are season-variable
# labels that the spatial-correlation and spatial-SD panel rows share; those rows are
# told apart by their row ylabel in the plotting cell.
metric_titles = (
    ['Mean-T\n°C', 'Mean-P\ncm/yr',
     'SeasAmp-T\n°C', 'SeasAmp-P\ncm/mo',
     'Trend-T\n°C/century', 'Trend-P\ncm/century',
     'Nino3.4-pr r', 'ELI-pr r', 'Nino3.4-T r', 'ELI-T r']
    + 2 * [f'{season}-{var}' for var in ('T', 'P') for season in ('DJF', 'MAM', 'JJA', 'SON')]
)


In [None]:
# Plotting PNW metric comparisons
#
# Figure layout: 4 rows x 8 columns, panels numbered row-major 0-31.
#   panels 0-5   -> metrics 0-5   (mean, seasonal amplitude, trend of T and P)
#   panels 8-11  -> metrics 6-9   (DJF Nino3.4 / ELI correlations)
#   panels 16-23 -> metrics 10-17 (seasonal spatial correlations)
#   panels 24-31 -> metrics 18-25 (seasonal spatial standard deviations)
#   panels 6, 7, 12-15 are unused and removed.
# Each panel shows:
#   - a boxplot of the model ensemble means;
#   - the model ensemble means as black dots;
#   - the observational values as cyan dots;
#   - a reference line at the obs value (or at 1 for correlation / SD panels).

REMOVED_PANELS = {6, 7, 12, 13, 14, 15}


def _panel_to_metric(panel):
    """Return the metric index (axis 1 of modmet) displayed in subplot `panel`."""
    if panel <= 5:
        return panel
    if 8 <= panel <= 11:
        return panel - 2
    return panel - 6


def _styled_boxplot(ax, data):
    """Draw a single thick-lined boxplot of `data` on `ax` with red outlier markers."""
    ax.boxplot(data,whiskerprops=dict(linestyle='-',linewidth=3),
               capprops=dict(linewidth=3),
               boxprops=dict(linewidth=3),
               medianprops=dict(linewidth=3),
               flierprops=dict(marker='o',markerfacecolor='r',markeredgecolor='k',linestyle='none'),
               widths=1.5)


tlinst = 'c-'
modmet_ensm = np.nanmean(modmet,axis=2)
mask = ~np.isnan(modmet_ensm)
# filty[m]: non-NaN model ensemble means for metric m (boxplot input)
filty = [d[m] for d, m in zip(modmet_ensm.T, mask.T)]
xran = [.05,1.95]
metric_title_size = 13

# Reference-line value for the first-row panels; every other panel uses 1
ref_values = {0: meanto.item(), 1: meanpo.item(), 2: sampto, 3: samppo, 4: ttrendo, 5: ptrendo}
# obsmet columns for panels 0 and 1 are shifted by the obs mean before plotting
# (presumably stored relative to meanto/meanpo -- confirm where obsmet is built)
obs_offsets = {0: meanto.item(), 1: meanpo.item()}
# Fixed (ylim, yticks) for the first-row panels
first_row_axes = {
    0: ([0.,13.], [0,3,6,9,12]),
    1: ([14.,118.], [25,50,75,100]),
    2: ([16.6,32.], [20,24,28,32]),
    3: ([2.,17.], [5,10,15]),
    4: ([-0.89,2.], [-0.,1,2]),
    5: ([-7.,11.82], [-5,0,5,10]),
}
# NOTE: absolute output path; change if running outside the original environment
metricarray_png = '/glade/u/home/nlybarger/CMIP6_1900-2014.PNW.metricarray.png'

fig,axs = plt.subplots(4,8,figsize=(11,10),sharex='all')
for i, ax in enumerate(axs.flat):
    if i in REMOVED_PANELS:
        ax.remove()
        continue
    m = _panel_to_metric(i)

    _styled_boxplot(ax, filty[m])

    #===================================================
    # Reference line, then per-row y-axis formatting

    ref = ref_values.get(i, 1)
    ax.plot(xran,[ref,ref],tlinst,linewidth=3,zorder=-1)
    if i <= 5:
        ylim, yticks = first_row_axes[i]
        ax.set_ylim(ylim)
        ax.set_yticks(yticks)
        if i == 1:
            # Nudge the top tick label horizontally
            ax.yaxis.get_majorticklabels()[3].set_x(.065)
        if i in (1, 5):
            # Only matters if the (currently disabled) ylabel is re-enabled
            ax.yaxis.set_label_coords(-.27, .5)
    elif i <= 11:
        ax.set_ylim([-1.1,1.2])
        if i > 8:
            ax.set_yticklabels([])
    elif i <= 23:
        ax.set_ylim([0.0,1.1])
        if i > 16:
            ax.set_yticklabels([])
        if i == 16:
            ax.set_ylabel('SpaceCorr',fontsize=14)
    else:
        ax.set_ylim([0.0,2.75])
        ax.set_yticks([0,0.5,1,1.5,2,2.5])
        if i > 24:
            ax.set_yticklabels([])
        if i == 24:
            ax.set_ylabel('SpaceSD',fontsize=14)

    #===================================================
    # Individual model ensemble means

    ax.scatter(np.ones(modmet_ensm.shape[0]),modmet_ensm[:,m],c='k',s=30)

    #===================================================
    # Observational values

    obs_vals = obsmet[:,m]
    if i in obs_offsets:
        obs_vals = obs_vals + obs_offsets[i]
    ax.scatter(np.ones(obsmet.shape[0]),obs_vals,c='cyan',s=75,edgecolors='k',zorder=5)

    #===================================================

    ax.set_title(metric_titles[m],fontsize=metric_title_size)
    ax.set_xlim(xran[0],xran[1])
    ax.set_xticks([])
    ax.tick_params(axis='both', which='major', labelsize=12)

fig.subplots_adjust(left=0.125,right=0.9,wspace=0.55,hspace=0.25)
fig.suptitle('Pacific Northwest Metrics',fontsize=22,y=0.98)
fig.savefig(metricarray_png,dpi=450,bbox_inches='tight',facecolor='w')

In [None]:
# Quick and dirty spatial resolution output: approximate grid spacing
# (lon spacing x lat spacing, presumably degrees) for each model. The spacing is
# estimated from the first two grid points only. Each line is labeled with its model.
for mod in models:
    if mod not in nci.keys():
        print(mod + ': not found in nci, resolution unavailable')
        continue
    dlon = round(abs((nci[mod]['lon'][0]-nci[mod]['lon'][1]).data),2)
    dlat = round(abs((nci[mod]['lat'][0]-nci[mod]['lat'][1]).data),2)
    print(mod + ': ' + str(dlon) + ' x ' + str(dlat))