In [1]:
import os
import glob
import numpy as np 
import xarray as xr
import pandas as pd
import datetime
from datetime import date, timedelta
import dask
import re
import scipy.stats as stats
import scipy.signal as signal
from skimage.measure import find_contours
from statsmodels.tsa.stattools import acf, pacf

import isla_interp_utils as isla_interp

# Plotting utils 
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.ticker as mticker
import cartopy
import cartopy.crs as ccrs
import seaborn as sns


In [3]:
def regrid_data(fromthis, tothis, method=1):
    """Regrid ``fromthis`` onto the lat/lon grid of ``tothis``.

    Parameters
    ----------
    fromthis : xarray.DataArray or xarray.Dataset
        Data to regrid.
    tothis : xarray object
        Provides the target grid. Method 1 uses ``interp_like(tothis)``;
        methods 2-4 read ``tothis['lat']`` and ``tothis['lon']``.
    method : int, default 1
        1 - ``interp_like``, applied one time step at a time when ``fromthis``
            has a ``time`` dimension (spatial regridding only).
        2 - ``interp`` onto the target lat/lon, passing other coords through.
        3 - xESMF bilinear ``Regridder``.
        4 - geocat ``linint2``.

    Returns
    -------
    The regridded xarray object.

    Raises
    ------
    ValueError
        If ``method`` is not 1, 2, 3 or 4 (the function used to return None silently).

    Notes
    -----
    Methods 3 and 4 use the names ``xe`` (xESMF) and ``geocat``, which this
    notebook does not import; they must already be in scope — TODO confirm.
    """
    if method not in (1, 2, 3, 4):
        raise ValueError(
            "regrid_data: unsupported method %r (expected 1, 2, 3 or 4)" % (method,))

    #Import necessary modules:
    import xarray as xr

    if method == 1:
        # kludgy: spatial regridding only, seems like can't automatically deal with time,
        # so interpolate one time step at a time and re-stack along 'time'.
        # Test dims (not coords): a scalar 'time' coordinate cannot be iterated or
        # isel'ed, so it falls through to a single interp_like call.
        if 'time' in fromthis.dims:
            result = [fromthis.isel(time=t).interp_like(tothis)
                      for t in range(fromthis.sizes['time'])]
            return xr.concat(result, 'time')
        return fromthis.interp_like(tothis)

    # Methods 2-4 all target the lat/lon coordinates of `tothis`.
    newlat = tothis['lat']
    newlon = tothis['lon']

    if method == 2:
        # Interpolate onto the target lat/lon; every other coordinate is passed
        # through with its existing values.
        coords = dict(fromthis.coords)
        coords['lat'] = newlat
        coords['lon'] = newlon
        return fromthis.interp(coords)

    if method == 3:
        # NOTE(review): `xe` (xESMF) is not imported in this notebook.
        ds_out = xr.Dataset({'lat': newlat, 'lon': newlon})
        regridder = xe.Regridder(fromthis, ds_out, 'bilinear')
        return regridder(fromthis)

    # method == 4: geocat. NOTE(review): `geocat` is not imported in this notebook.
    result = geocat.comp.linint2(fromthis, newlon, newlat, False)
    result.name = fromthis.name
    return result

In [4]:
# Grabbed from Brian M. to use time midpoints, not end periods
def cesm_correct_time(ds, bnds_dim=None):
    """Replace ``ds['time']`` by the midpoint of ``ds['time_bnds']`` and decode times.

    The purpose is to center each timestamp on its averaging interval,
    rather than leaving it at the end of the period.

    Parameters
    ----------
    ds : xarray.Dataset
        Must contain ``time`` and ``time_bnds``. It should have been opened
        with ``decode_times=False``; ``xr.decode_cf`` is applied here.
    bnds_dim : str, optional
        Name of the bounds dimension of ``time_bnds``. By default this is the
        single dimension of ``time_bnds`` other than ``'time'`` (e.g. ``'nbnd'``).
        If that is ambiguous, it falls back to ``'nbnd'``, the previous hard-coded value.

    Returns
    -------
    xarray.Dataset
        The dataset with the midpoint time coordinate, CF-decoded.

    Raises
    ------
    AssertionError
        If ``time_bnds`` or ``time`` is missing.
    """
    assert 'time_bnds' in ds, "cesm_correct_time: dataset has no 'time_bnds' variable"
    assert 'time' in ds, "cesm_correct_time: dataset has no 'time' coordinate"

    if bnds_dim is None:
        other_dims = [d for d in ds['time_bnds'].dims if d != 'time']
        bnds_dim = other_dims[0] if len(other_dims) == 1 else 'nbnd'

    correct_time_values = ds['time_bnds'].mean(dim=bnds_dim)
    # copy any metadata (e.g. units/calendar, which decode_cf relies on):
    correct_time_values.attrs = ds['time'].attrs
    ds = ds.assign_coords({"time": correct_time_values})
    ds = xr.decode_cf(ds)  # decode to datetime objects
    return ds

# - - - - - - - - - - - - - - - 
# Pre-process data while reading in 
# - - - - - - - - - - - - - - - 

def preprocess(ds, lat_s=-10, lat_n=10):
    """Per-file preprocessing hook for ``xr.open_mfdataset``.

    Centers the time axis with ``cesm_correct_time`` and keeps only the
    latitude band ``[lat_s, lat_n]``.

    Parameters
    ----------
    ds : xarray.Dataset
        One file's dataset, opened with ``decode_times=False``.
    lat_s, lat_n : float, default -10, 10
        Southern and northern edges of the band that is kept.
        ``open_mfdataset`` passes only ``ds``, so the defaults apply there.

    Notes
    -----
    ``slice(lat_s, lat_n)`` assumes ``lat`` is ascending; a descending ``lat``
    would select nothing — TODO confirm for new inputs.
    """
    dsCorr = cesm_correct_time(ds)
    return dsCorr.sel(lat=slice(lat_s, lat_n))


In [5]:
## Some basics - the region to focus on, for one
# Latitude band (degrees) kept when reading the data.
lat_n, lat_s = 10.0, -10.0

# NOTE on the longitude names below: the values appear to be degrees east on a
# 0-360 axis. Despite the names, every ``lon_e*`` holds the smaller (western)
# edge of its box, and every ``lon_w*`` holds the larger (eastern) edge.
# The names are kept unchanged because later cells may reference them.

# Nino3.4 (190E-240E, i.e. 170W-120W)
lat_n34, lat_s34 = 5, -5
lon_e34, lon_w34 = 190, 240

# Nino3 (210E-270E, i.e. 150W-90W)
lat_n3, lat_s3 = 5, -5
lon_e3, lon_w3 = 210, 270

# Nino 4 (160E-210E, i.e. 160E-150W)
lat_n4, lat_s4 = 5, -5
lon_e4, lon_w4 = 160, 210


## Read in CESM1/CESM2 pre-industrial control (PI) surface temperature data

In [5]:
# Monthly surface temperature files: CESM2 CMIP6 piControl ('ts') and the
# CESM1 large-ensemble B1850 control run ('TS').
cesm2_dir = '/glade/campaign/collections/cdg/data/CMIP6/CMIP/NCAR/CESM2/piControl/r1i1p1f1/Amon/ts/gn/files/d20190320/'
cesm1_dir = '/glade/campaign/cesm/collections/cesmLE/CESM-CAM5-BGC-LE/atm/proc/tseries/monthly/TS/'

# Sorted by filename so the files are concatenated in time order.
listFiles_cesm1 = np.sort(glob.glob(os.path.join(cesm1_dir, 'b.e11.B1850C5CN.f09_g16.005.cam.h0.TS.*nc')))
listFiles_cesm2 = np.sort(glob.glob(os.path.join(cesm2_dir, 'ts_Amon_CESM2_piControl_r1i1p1f1_gn*.nc')))

# Both records are opened identically. Times stay undecoded so that
# `preprocess` can rebuild them from time_bnds before decoding.
mfdataset_kwargs = dict(
    preprocess=preprocess,
    concat_dim='time',
    combine='nested',
    decode_times=False,
    data_vars='minimal',
    parallel=True,
)
DS_all_cesm1 = xr.open_mfdataset(listFiles_cesm1, **mfdataset_kwargs)
DS_all_cesm2 = xr.open_mfdataset(listFiles_cesm2, **mfdataset_kwargs)


  new_vars[k] = decode_cf_variable(
  new_vars[k] = decode_cf_variable(
  new_vars[k] = decode_cf_variable(
  new_vars[k] = decode_cf_variable(
  new_vars[k] = decode_cf_variable(
  new_vars[k] = decode_cf_variable(
  new_vars[k] = decode_cf_variable(
  new_vars[k] = decode_cf_variable(
  new_vars[k] = decode_cf_variable(
  new_vars[k] = decode_cf_variable(
  new_vars[k] = decode_cf_variable(
  new_vars[k] = decode_cf_variable(


In [7]:
%%time 

nYears_cesm1 = len(DS_all_cesm1.time.values)/12
nYears_cesm2 = len(DS_all_cesm2.time.values)/12
print('CESM1 has %i years.\nCESM2 has %i years.' % (nYears_cesm1,nYears_cesm2) )

# n_spacing    = int(25)
# nYearsPer    = int(50)
n_spacing    = int(40)
nYearsPer    = int(40)
# nSamples  = int(np.floor(np.nanmin([((nYears_cesm1 - nYearsPer) // n_spacing + 1), ((nYears_cesm2 - nYearsPer) // n_spacing + 1)])) )
nSamples  = int(np.floor(np.nanmin([((nYears_cesm1 - nYearsPer) // n_spacing + 1), ((nYears_cesm2 - nYearsPer) // n_spacing + 1)])) )
print('Creating %i samples to compute ENSO over, starting every %i years and each with a window of %i years.' % (nSamples, n_spacing, nYearsPer))

#nStartInd = np.arange(0, nSamples*12, n_spacing*12)

# Create an empty array to store the reshaped data
events_cesm1 = np.zeros([nSamples, nYearsPer*12, len(DS_all_cesm1.lat.values), len(DS_all_cesm1.lon.values)])
events_cesm2 = np.zeros([nSamples, nYearsPer*12, len(DS_all_cesm2.lat.values), len(DS_all_cesm2.lon.values)])

# Create a new DataArray with the 'Event' axis
event_coords = np.arange(nSamples)
time_coords = np.arange(nYearsPer*12)

# DS_events_cesm1 = xr.DataArray(events_cesm1, coords=[event_coords, time_coords, DS_all_cesm1.lat.values, DS_all_cesm1.lon.values], 
#                                dims=["event", "time", "lat","lon"])

# DS_events_cesm2 = xr.DataArray(events_cesm2, coords=[event_coords, time_coords, DS_all_cesm2.lat.values, DS_all_cesm2.lon.values], 
                               # dims=["event", "time", "lat","lon"])

## Loop over the events and fill the new array
for iENSO in range(nSamples):
    start_year = iENSO * (n_spacing*12)
    events_cesm1[iENSO, :, :,:] = DS_all_cesm1.TS.isel(time=slice(start_year, (start_year + nYearsPer*12)))
    events_cesm2[iENSO, :, :,:] = DS_all_cesm2.ts.isel(time=slice(start_year, (start_year + nYearsPer*12)))

    # print('Starting with time index %i' % (start_year))




CESM1 has 1801 years.
CESM2 has 1200 years.
Creating 30 samples to compute ENSO over, starting every 40 years and each with a window of 40 years.
CPU times: user 35 s, sys: 1.98 s, total: 37 s
Wall time: 1min


In [8]:
# Wrap the stacked windows as DataArrays with a leading 'event' axis.
# Every event receives the SAME time labels: nYearsPer*12 consecutive CESM1
# timestamps starting at index 17400 (a multiple of 12). Each event holds a
# different window of data, so these labels mark position within the window,
# not the true dates of the data.
# NOTE(review): the CESM2 events are also labeled with these CESM1 timestamps.
time_label_start = 17400
event_coords = np.arange(nSamples)
time_coords = DS_all_cesm1.time.values[time_label_start:(time_label_start + (nYearsPer*12))]
event_dims = ["event", "time", "lat", "lon"]


def _label_events(events, grid_source):
    """Attach event/time labels plus the lat/lon values of ``grid_source``."""
    grid_coords = [event_coords, time_coords,
                   grid_source.lat.values, grid_source.lon.values]
    return xr.DataArray(events, coords=grid_coords, dims=event_dims)


DS_cesm1_events = _label_events(events_cesm1, DS_all_cesm1)
DS_cesm2_events = _label_events(events_cesm2, DS_all_cesm2)


## Read in CESM3dev data

In [2]:
# Surface geopotential (PHIS) from a CESM2-LE historical run (LE2-1001.001),
# first time step only. It is used as `phi_sfc` in preprocess_h0.
# The data is loaded into memory and the file is closed right away instead of
# leaving the handle open for the rest of the session.
with xr.open_dataset("/glade/campaign/cesm/collections/CESM2-LE/atm/proc/tseries/month_1/PHIS/"
                     +"b.e21.BHISTcmip6.f09_g17.LE2-1001.001.cam.h0.PHIS.185001-185912.nc") as _phis_file:
    phis = _phis_file.isel(time=0).load()

In [7]:
def preprocess_h0(DS, lat_s=-10, lat_n=10):
    """Per-file preprocessing hook for the CESM3dev ``h0a`` history files.

    Steps:
    - Keeps the latitude band ``[lat_s, lat_n]``.
    - Interpolates U to 85000 Pa and OMEGA to 50000 Pa from hybrid levels.
    - Returns the variables in ``climoVar_list`` plus ``U850`` and ``OMEGA500``.

    Parameters
    ----------
    DS : xarray.Dataset
        One history file. Must contain U, OMEGA, T, PS, hyam, hybm, a ``lev``
        dimension, and every name in ``climoVar_list``.
    lat_s, lat_n : float, default -10, 10
        Latitude band that is kept. ``open_mfdataset`` passes only ``DS``.

    Notes
    -----
    Uses the notebook-level ``phis`` as ``phi_sfc``.
    NOTE(review): ``phis`` is a whole Dataset (not ``phis.PHIS``) read from an
    f09 CESM2-LE file. It is presumably unused because ``extrapolate=False``;
    confirm against ``isla_interp_utils``.
    """
    climoVar_list = ['LHFLX','SHFLX','LWCF','SWCF','PRECT','PS','TAUX','TAUY','TGCLDLWP','U10','TREFHT','UBOT','TS']

    dsSel = DS.sel(lat=slice(lat_s, lat_n))

    # Temperature on the last model level, computed once and shared by both interpolations.
    t_bot = dsSel.T.isel(lev=dsSel.lev.size-1)

    def _to_pressure(var, level_pa):
        """Interpolate ``var`` from hybrid levels to the single pressure level ``level_pa``."""
        return isla_interp.interp_hybrid_to_pressure(
            var, dsSel.PS, dsSel.hyam, dsSel.hybm, p0=1e5, new_levels=np.array([level_pa]),
            method='log', lev_dim='lev', extrapolate=False, variable='other',
            t_bot=t_bot, phi_sfc=phis)

    ## Interpolate to set levels
    u850 = _to_pressure(dsSel.U, 85000.)
    omega500 = _to_pressure(dsSel.OMEGA, 50000.)

    dsSel = dsSel[climoVar_list]
    dsSel['U850'] = u850
    dsSel['OMEGA500'] = omega500

    return dsSel

In [None]:
dataDir = '/glade/derecho/scratch/hannay/archive/'
caseNames = ['b.e30_beta04.BLT1850.ne30_t232_wgx3.121', 
             'b.e30_beta04.BLTHIST.ne30_t232_wgx3.121', 
             'b.e30_beta04.BLT1850.ne30_t232_wgx3.121_1pctco2',
             'b.e30_beta04.BLT1850.ne30_t232_wgx3.121_4xco2',
             'b.e30_beta05.BLT1850.ne30_t232_wgx3.125',
            ]
shortNames = ['121_preInd', 
              '121_hist',
              '121_1pctCO2',
              '121_4xCO2',
              '125',
             ]
assert len(caseNames) == len(shortNames), "caseNames and shortNames must pair up one-to-one"

# For each case:
# - open its h0a history files through preprocess_h0;
# - label the result with a scalar 'case' coordinate.
# All cases are then stacked along a new 'case' dimension.
# The per-case datasets are collected first and concatenated once, instead of
# rebuilding the combined object on every iteration.
caseDS_list = []
for caseName, shortName in zip(caseNames, shortNames):
    listFiles = np.sort(glob.glob(dataDir+caseName+'/atm/hist/'+'*.h0a.*.nc'))
    if len(listFiles) == 0:
        raise FileNotFoundError('No h0a files found for case %s under %s' % (caseName, dataDir))

    camDS = xr.open_mfdataset(listFiles, preprocess=preprocess_h0, concat_dim='time', combine='nested', 
                              decode_times=True, data_vars='minimal', parallel=True)

    # squeeze() drops every length-1 dimension (presumably including the single
    # pressure level of U850/OMEGA500).
    caseDS = camDS.squeeze().assign_coords({"case": shortName})
    caseDS_list.append(caseDS)

camDS_all = xr.concat(caseDS_list, "case")

    

## Compute detrended, ocean-only anomalies (SST, precipitation, U850, zonal wind stress)

In [None]:
## Compute anomalies
# Pipeline for each field:
# 1. Linearly detrend in time.
# 2. Keep ocean points via cesm2_ocnMask.
# 3. Wrap as a (time, lat, lon) DataArray on the cesm2_pr grid.
# 4. Remove the annual cycle with rmMonAnnCyc.
# NOTE(review): this cell rebinds cesm2_sst / cesm2_u850 / cesm2_taux, which are
# also its INPUTS (datasets with .ts/.ua/.tauu). Re-running it without reloading
# those inputs fails. The names are kept because later cells may use them.


def _detrend_time(values, fill_value=-99):
    """Linearly detrend ``values`` along axis 0 (time), tolerating NaNs.

    scipy.signal.detrend cannot handle NaNs. NaNs are therefore temporarily
    replaced by ``fill_value``, the array is detrended, and the NaN positions
    are restored as NaN. Works on a copy; the input array is never modified.
    """
    values = np.asarray(values)
    nan_mask = np.isnan(values)
    detrended = signal.detrend(np.where(nan_mask, fill_value, values), axis=0, type='linear')
    detrended[nan_mask] = np.nan
    return detrended


def _on_cesm2_grid(values):
    """Wrap ``values`` as a DataArray with cesm2_pr's time/lat/lon coordinates."""
    return xr.DataArray(values,
        coords={'time': cesm2_pr.time.values,
                'lat':  cesm2_pr.lat.values, 
                'lon':  cesm2_pr.lon.values}, 
        dims=["time", "lat", "lon"])


# Detrend data
SST = cesm2_sst.ts.values
sst_detrend = _detrend_time(SST)

PRECT = cesm2_pr.pr 
prect_detrend = _detrend_time(PRECT)

U850 = cesm2_u850.ua
u850_detrend = _detrend_time(U850)

TAUX = cesm2_taux.tauu
taux_detrend = _detrend_time(TAUX)

# Get ocean values only, labeled on the cesm2_pr grid
cesm2_prect = _on_cesm2_grid(prect_detrend * cesm2_ocnMask)
cesm2_sst   = _on_cesm2_grid(sst_detrend * cesm2_ocnMask)
cesm2_u850  = _on_cesm2_grid(u850_detrend * cesm2_ocnMask)
cesm2_taux  = _on_cesm2_grid(taux_detrend * cesm2_ocnMask)

# Also remove annual cycle
sst_cesm2_anom   = rmMonAnnCyc(cesm2_sst)
prect_cesm2_anom = rmMonAnnCyc(cesm2_prect)
u850_cesm2_anom  = rmMonAnnCyc(cesm2_u850)
taux_cesm2_anom  = rmMonAnnCyc(cesm2_taux)
