In [10]:
# Clear all user variables without the interactive y/[n] confirmation prompt,
# so "Restart Kernel -> Run All" does not stall waiting for input.
# Equivalent to the `%reset -f` line magic.
get_ipython().run_line_magic('reset', '-f')

Once deleted, variables cannot be recovered. Proceed (y/[n])?  y


In [1]:
import os
import sys
# block warnings from printing
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

import collections
import xarray as xr
xr.set_options(keep_attrs=True)
import netCDF4 as nc
import numpy as np
np.seterr(divide='ignore', invalid='ignore')
import metpy.calc as mp
from metpy.units import units
from scipy.stats import ttest_ind, ttest_rel
from datetime import datetime

import cartopy
cartopy.config['data_dir'] = "/discover/nobackup/projects/jh_tutorials/JH_examples/JH_datafiles/Cartopy"
cartopy.config['pre_existing_data_dir'] = "/discover/nobackup/projects/jh_tutorials/JH_examples/JH_datafiles/Cartopy"
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
from shapely.geometry.polygon import LinearRing

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.offsetbox import AnchoredText
import matplotlib.gridspec as gridspec
import matplotlib.ticker as mticker
import matplotlib.colors as mcolors
from matplotlib.colors import TwoSlopeNorm
from matplotlib import cm
from matplotlib.colors import ListedColormap,LinearSegmentedColormap
import cmocean.cm as cmo


# add path to custom functions (../../py_functions relative to the working directory)
module_path = os.path.abspath(os.path.join('../..'))
functions_path = os.path.join(module_path, 'py_functions')
# test the same path that is appended, so re-running this cell does not add duplicates
if functions_path not in sys.path:
    sys.path.append(functions_path)
# import custom functions
from map_plot_tools import *
from colorbar_funcs import *


# settings
%config InlineBackend.figure_format = 'retina'


In [2]:
### +++ FUNCTIONS +++ ###

## TIME-SERIES MEAN SEASONAL CLIMATOLOGIES ##
def season_mean(ds, calendar='standard'):
    """Average ``ds`` over time within each meteorological season.

    Groups on the ``time.season`` virtual coordinate and takes a plain
    (unweighted) mean over ``time`` in each group.

    ``calendar`` is accepted for interface compatibility but is not used.
    """
    by_season = ds.groupby('time.season')
    return by_season.mean(dim='time')

def annual_mean(ds, calendar='standard'):
    """Average ``ds`` over time within each calendar year.

    Groups on the ``time.year`` virtual coordinate and takes a plain
    (unweighted) mean over ``time`` in each group.

    ``calendar`` is accepted for interface compatibility but is not used.
    """
    by_year = ds.groupby('time.year')
    return by_year.mean(dim='time')

## ANNUAL SEASONAL MEAN CLIMATOLOGIES ##
def annual_season_mean(ds, calendar='standard'):
    """Year-by-year seasonal means for each meteorological season.

    Parameters
    ----------
    ds : xarray object with a ``time`` coordinate
    calendar : str
        Accepted for interface compatibility; not used.

    Returns
    -------
    dict
        Maps 'DJF', 'MAM', 'JJA', 'SON' (in that order) to ``ds`` averaged
        over that season's time steps within each year (``year`` dimension).

    DJF is labelled with the year of its January/February. The DJF-masked data
    are shifted one time step later before grouping by year. For contiguous
    monthly data (assumed, TODO confirm for other inputs) each December
    therefore lands in the following year's January slot. The first year has no
    preceding December, so it is dropped from the DJF result.
    """
    result = {}
    for season_name in ('DJF', 'MAM', 'JJA', 'SON'):
        # keep only this season's time steps (others become missing)
        masked = ds.where(ds['time.season'] == season_name)
        is_djf = season_name == 'DJF'
        if is_djf:
            # move each value one step later so Dec joins the next year's Jan/Feb
            masked = masked.shift(time=1)
        yearly = masked.groupby('time.year').mean(dim='time')
        if is_djf:
            # drop the first year: its DJF mean lacks the preceding December
            first_full = yearly.year.min() + 1
            last = yearly.year.max()
            yearly = yearly.sel(year=slice(first_full, last))
        result[season_name] = yearly
    return result


## STATISTICAL SIGNIFICANCE TEST ##
# test with identical sample sizes
def sigtest(yearmean1, yearmean2, timemean1, timemean2, alpha=0.1):
    """Paired (related-samples) t-test plus a significance-masked difference.

    Parameters
    ----------
    yearmean1, yearmean2 : array-like
        Samples along axis 0 (e.g. one value per year). The test is paired,
        so both must have the same length along axis 0.
    timemean1, timemean2 : array-like
        Time means whose difference is returned.
    alpha : float, optional
        Significance level. Differences with p > alpha are masked. The default
        of 0.1 (90% confidence) matches the original hard-coded threshold.

    Returns
    -------
    diff : timemean1 - timemean2
    diff_mask : numpy masked array of ``diff``, masked where the difference is
        not significant (p > alpha) or the p-value is NaN (test undefined,
        e.g. missing data).
    ptvals : result of ``scipy.stats.ttest_rel`` (statistic, pvalue)
    """
    ptvals = ttest_rel(yearmean1, yearmean2, axis=0)
    diff = timemean1 - timemean2
    pvalue = np.asarray(ptvals[1])
    # `~(p <= alpha)` is True for NaN p-values too, whereas `p > alpha` is not
    not_significant = ~(pvalue <= alpha)
    diff_mask = np.ma.masked_where(not_significant, diff)
    return diff, diff_mask, ptvals

# test with different sample sizes
def sigtest2n(yearmean1, yearmean2, timemean1, timemean2, alpha=0.1):
    """Welch's (unequal-variance) t-test plus a significance-masked difference.

    Unlike ``sigtest``, the two samples may have different lengths along axis 0.

    Parameters
    ----------
    yearmean1, yearmean2 : array-like
        Independent samples along axis 0 (e.g. one value per year).
    timemean1, timemean2 : array-like
        Time means whose difference is returned.
    alpha : float, optional
        Significance level. Differences with p > alpha are masked. The default
        of 0.1 (90% confidence) matches the original hard-coded threshold.

    Returns
    -------
    diff : timemean1 - timemean2
    diff_mask : numpy masked array of ``diff``, masked where the difference is
        not significant (p > alpha) or the p-value is NaN (test undefined,
        e.g. missing data).
    ptvals : result of ``scipy.stats.ttest_ind(..., equal_var=False)``
    """
    ptvals = ttest_ind(yearmean1, yearmean2, axis=0, equal_var=False)
    diff = timemean1 - timemean2
    pvalue = np.asarray(ptvals[1])
    # `~(p <= alpha)` is True for NaN p-values too, whereas `p > alpha` is not
    not_significant = ~(pvalue <= alpha)
    diff_mask = np.ma.masked_where(not_significant, diff)
    return diff, diff_mask, ptvals

def latitude_weighted_mean(var):
    """
    Area-weighted spatial mean of a gridded DataArray.

    Each grid cell is weighted by cos(latitude), which accounts for cells
    shrinking toward the poles. The weighted mean is taken over both the
    latitude and longitude dimensions, with attributes kept.
    """
    lon_coord, lat_coord = get_xy_coords(var)
    cos_lat = np.cos(np.deg2rad(lat_coord))
    cos_lat.name = 'weights'
    reduce_dims = [lat_coord.name, lon_coord.name]
    return var.weighted(cos_lat).mean(dim=reduce_dims, keep_attrs=True)

def get_xy_coords(var):
    """
    Get lon and lat arrays without knowing coordinate names.

    The coordinates are located with MetPy's ``var.metpy.coordinates('x', 'y')``
    accessor.

    Returns
    -------
    (x, y) : the x (longitude) and y (latitude) coordinate arrays of ``var``.

    Raises
    ------
    TypeError
        If ``var`` is not an xarray DataArray (e.g. a Dataset). Previously a
        message was printed and None returned, which made callers that unpack
        ``x, y = get_xy_coords(var)`` fail with an unrelated error.
    """
    if isinstance(var, xr.Dataset):
        raise TypeError('This is a dataset. Please use an xarray DataArray')
    if not isinstance(var, xr.DataArray):
        raise TypeError(f'Expected an xarray DataArray, got {type(var).__name__}')
    x, y = var.metpy.coordinates('x', 'y')
    return (x, y)
        
def longitude_flip(var):
    """ Convert longitude values from the -180:180 to 0:360 convention or vice versa.

        ** Only works for global data. Do not apply to data with a clipped longitude range **

        The data values are rolled by nx//2 points along longitude. On a uniform
        global grid with an even number of longitudes, that is 180 degrees. The
        longitude coordinate is then replaced with the original range shifted by
        +180 (if any longitude is negative) or -180 (if any exceeds 180).

        Parameters
        ----------
        var : Data Array
            Must have a longitude coordinate that ``get_xy_coords`` can find.

        Returns
        -------
        Data Array
            Rolled copy with the new longitude values. An entry
            'flipped longitudes <timestamp>' is appended to ``attrs['history']``
            (any existing history is kept), and the original longitudes are
            stored in ``attrs['original_lons']``.

        Raises
        ------
        ValueError
            If all longitudes lie within 0..180, so the convention cannot be
            determined.
    """
    # get var info
    x, _ = get_xy_coords(var)  # original longitude coordinate
    lon_name = x.name          # name of the longitude coordinate
    nx = len(x)                # number of longitudes
    lon_min = float(x.min())
    lon_max = float(x.max())

    # determine the current convention and the offset to the opposite one
    if lon_min < 0:
        # negative values: data is -180:180, switch to 0:360
        offset = 180
    elif lon_max > 180:
        # values above 180: data is 0:360, switch to -180:180
        offset = -180
    else:
        # e.g. regional data within 0..180: the flip is undefined
        raise ValueError(
            f'Cannot determine longitude convention of {lon_name!r}: values span '
            f'{lon_min}..{lon_max}. longitude_flip requires global data.'
        )
    new_lons = np.linspace(lon_min + offset, lon_max + offset, nx)

    # shift the data by 180 degrees of longitude (coordinates are replaced below)
    nshift = nx // 2
    var = var.roll({lon_name: nshift}, roll_coords=False)

    # update longitude coord with new values
    var = var.assign_coords({lon_name: new_lons})

    # add attributes documenting the change, keeping any existing history
    timestamp = datetime.now().strftime("%B %d, %Y, %r")
    entry = f'flipped longitudes {timestamp}'
    previous = var.attrs.get('history')
    var.attrs['history'] = f'{previous}\n{entry}' if previous else entry
    var.attrs['original_lons'] = x.values

    return var

In [3]:
### +++ DATA PATHS +++ ###

# E2.1 coupled simulations and their short labels (paired by position)
runNames = ['E2pt1_PIctrl_restart', 'cam.sam.1senv.cpld']
keys = ['ctrl', 'cam+sam']
diag = 'w'  # diagnostic name used in the file names

# top-level data directory
dpath0 = '/discover/nobackup/projects/giss/baldwin_nip/dmkumar'

# {label: path to that run's `diag` timeseries file}
files = {}
for key, run in zip(keys, runNames):
    files[key] = f'{dpath0}/{run}/timeseries/{diag}_timeseries.{run}.nc'

In [23]:
## ORGANIZE DATA AND CALCULATE CLIMATOLOGIES

diag='w' # diagnostics

# number of leading years to omit, and end index (exclusive) of the kept period;
# both are counted in time steps, assuming monthly data (12 per year)
omit_yrs = 30
n_omit = omit_yrs * 12
tmax = 350 * 12

# lat/lon bounds of the analysis region
# (lat=slice(latmin, latmax) below relies on an ascending lat coordinate)
lonmin = 150
lonmax = 330
latmin = -55
latmax = 30

# latitude band averaged for the zonal profile (same band as the box drawn on
# the 500 hPa map). The bounds must be ordered low -> high to match the
# ascending lat coordinate; a descending slice such as slice(-10, -40) selects
# nothing.
prof_latmin = -40
prof_latmax = -10

# pressure level [hPa] used for the map panel
plev_map = 500

# new pressure levels (reference list; the map level is interpolated directly below)
plevels = [1000,950,900,850,800,750,700,650,600,550,500,450,400,350,300,250,200,150,100,50]

# initialize dictionaries
dat = {}
time_mean = {}
ann_mean = {}
seas_mean = {}
ann_seas_mean = {}


## Data cleaning
print('Reading in data...')
for key in ['ctrl', 'cam+sam']:
    ds = xr.open_dataset(files[key], chunks={})[diag].isel(time=slice(n_omit,tmax))
    # flip the longitude convention; lonmin/lonmax are given in 0:360
    # (presumably the source files use -180:180 -- confirm)
    ds_flip = longitude_flip(ds)
    dat[key] = ds_flip.sel(lat=slice(latmin,latmax), lon=slice(lonmin,lonmax))
    del ds, ds_flip

## Calculate time-means
print('Calculating time means for...')
for key in ['ctrl', 'cam+sam']:
    print(f'{key}')
    # initialize simulation sub-dicts
    time_mean[key] = {}
    ann_mean[key] = {}
    seas_mean[key] = {}
    ann_seas_mean[key] = {}

    for plot in ['500mb']:
        # interpolate straight to the map level. This gives the same values as
        # interpolating onto every level in `plevels` and then selecting 500,
        # without computing the unused levels.
        dat_interp = dat[key].interp(plm=plev_map)
        time_mean[key][plot] = dat_interp.mean(dim='time')
        ann_mean[key][plot] = annual_mean(dat_interp)
        seas_mean[key][plot] = season_mean(dat_interp)
        ann_seas_mean[key][plot] = annual_season_mean(dat_interp)

    for plot in ['zonal_profile']:
        # unweighted mean over the prof_latmin..prof_latmax latitude band
        dat_band = dat[key].sel(lat=slice(prof_latmin, prof_latmax))
        assert dat_band.sizes['lat'] > 0, 'empty latitude band: check slice order against lat ordering'
        dat_mm = dat_band.mean(dim='lat')
        time_mean[key][plot] = dat_mm.mean(dim='time')
        ann_mean[key][plot] = annual_mean(dat_mm)
        seas_mean[key][plot] = season_mean(dat_mm)
        ann_seas_mean[key][plot] = annual_season_mean(dat_mm)

print('Done.')


Reading in data...
Calculating time means for...
ctrl
cam+sam
Done.


In [None]:
## ++ CALCULATE SIGNIFICANCE OF MODEL RUN DIFFERENCES & MASK ++ ##

# initialize dictionaries
diff = collections.defaultdict(dict)
diff_mask = collections.defaultdict(dict)
ptvals = collections.defaultdict(dict)

## COMPARING CONTROL MODEL RUNS WITH E2.1 MODIFIED TOPOGRAPHY RUNS
# For simulations/datasets with the same number of samples: paired t-test
# (sigtest) on the annual seasonal means. Both runs use the same isel time
# slice in the data cell, so their year counts are expected to match.
print('Significance testing for...')
# loop explicitly over the perturbed runs rather than relying on `key`
# leaking from an earlier cell's loop
for key in ['cam+sam']:
    print(f'{key}')
    for plot in ['500mb','zonal_profile']:
        # initialize simulation sub-dicts
        diff[key][plot] = {}
        diff_mask[key][plot] = {}
        ptvals[key][plot] = {}

        for season in ['DJF','MAM','JJA','SON']:
            # significance of (run - ctrl) for this season
            diff_, diff_mask_, ptvals_ = sigtest(ann_seas_mean[key][plot][season], ann_seas_mean['ctrl'][plot][season],
                                                 seas_mean[key][plot].sel(season=season), seas_mean['ctrl'][plot].sel(season=season))
            diff[key][plot][season] = diff_
            diff_mask[key][plot][season] = diff_mask_
            ptvals[key][plot][season] = ptvals_

print('Done')

# NOTE(review): the obs comparison below is disabled (it is an unused string
# literal). It will not run as written:
#   - no 'obs' entry is loaded in the data cells;
#   - `seasons` is not defined in the cells shown;
#   - time_mean / seas_mean / ann_seas_mean are now nested by plot
#     (e.g. time_mean['ctrl']['500mb']), which this code does not index.
"""
## COMPARING CONTROL MODEL RUNS WITH OBS.
# need to re-grid obs data to E2.1 grid first
regrid= {}
seas_mean_regrid = {}
ann_seas_mean_regrid = collections.defaultdict(dict)
lats=time_mean['ctrl'].lat
lons=time_mean['ctrl'].lon

for key in ['obs']:
    print(f'{key}')
    # re-grid time mean obs
    regrid[key] = time_mean[key].interp(lat=lats, lon=lons, method='linear')
    # re-grid seasonal time-mean obs
    seas_mean_regrid[key] = seas_mean[key].interp(lat=lats, lon=lons, method='linear')
    # re-grid annual seasonal mean obs
    for season in seasons:
        ann_seas_mean_regrid[key][season] = ann_seas_mean[key][season].interp(lat=lats, lon=lons, method='linear')
        # calculate significance (not sure why sigtest2n requires adding the '.values' and sigtest function doesn't?)
        diff_, diff_mask_, ptvals_ = sigtest2n(ann_seas_mean_regrid[key][season].values, ann_seas_mean['ctrl'][season].values,
                                                seas_mean_regrid[key].sel(season=season).values, seas_mean['ctrl'].sel(season=season).values)
        diff[key][season] = diff_
        diff_mask[key][season] = diff_mask_
        ptvals[key][season] = ptvals_

print('Done.')
"""

Significance testing for...


In [11]:
## ++ CALCULATE DIFFERENCE OF CLIMATOLOGICAL FIELDS BETWEEN PERTURBED & CONTROL MODEL RUNS ++ ##

# initialize dictionaries: {run: {plot: run minus 'ctrl'}}
mdiff = {}
seas_mdiff = {}

print('Calculating field differences.')

for key in ['cam+sam']:
    # full-period time-mean difference and per-season difference for each plot type
    mdiff[key] = {
        plot: time_mean[key][plot] - time_mean['ctrl'][plot]
        for plot in ('500mb', 'zonal_profile')
    }
    seas_mdiff[key] = {
        plot: seas_mean[key][plot] - seas_mean['ctrl'][plot]
        for plot in ('500mb', 'zonal_profile')
    }
print('Done.')

Calculating field differences.
Done.


In [14]:
# inspect the MAM seasonal-mean zonal-profile difference (cam+sam minus ctrl)
seas_mdiff['cam+sam']['zonal_profile'].sel({'season': 'MAM'})

Unnamed: 0,Array,Chunk
Bytes,3.75 kiB,3.75 kiB
Shape,"(40, 24)","(40, 24)"
Dask graph,1 chunks in 48 graph layers,1 chunks in 48 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 3.75 kiB 3.75 kiB Shape (40, 24) (40, 24) Dask graph 1 chunks in 48 graph layers Data type float32 numpy.ndarray",24  40,

Unnamed: 0,Array,Chunk
Bytes,3.75 kiB,3.75 kiB
Shape,"(40, 24)","(40, 24)"
Dask graph,1 chunks in 48 graph layers,1 chunks in 48 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [None]:
# -------------------- #
#       Settings       #
# -------------------- #
# plot specs
text_kw={'color':  'k', 'weight': 'bold', 'ha':'center', 'va':'center'}
tkw = {'axis': 'both', 'direction':'in', 'labelsize': 10}
grid_kw = {'color': 'black', 'weight': 'normal', 'size':10}
# var specs: season shown in both panels
season='MAM'
# colormap: 10 discrete bins from vmin to vmax. The map and the colorbar share
# the BoundaryNorm, and the profile uses the same `levels`, so colors agree.
cmap=cm.RdBu
vmin=-0.015
vmax=0.015
levels=np.linspace(vmin, vmax, 11)
norm=mpl.colors.BoundaryNorm(levels, cmap.N)
# map
trans=ccrs.PlateCarree()
proj=ccrs.PlateCarree(central_longitude=180)
xmin=150
xmax=325
map_bnds=[xmin, xmax, -45., 15.]
# seasonal-mean differences (cam+sam minus ctrl) for the selected season
map_diff=seas_mdiff['cam+sam']['500mb'].sel(season=season)
prof_diff=seas_mdiff['cam+sam']['zonal_profile'].sel(season=season)
lats=map_diff.lat
lons=map_diff.lon
levs=prof_diff.plm

# -------------------- #
#          Fig         #
# -------------------- #
fig=plt.figure(figsize=(10, 8), layout='constrained')

# (a) 500 hPa map of the difference; the box outlines lon 150-325, lat -40 to -10
ax=plt.subplot2grid((2,1), (0, 0), colspan=1, rowspan=1, projection=proj)
ax.text(141.1,11.5,'a.', size=20, **text_kw, zorder=100, bbox=dict(facecolor='white', edgecolor='black'))

#ax.pcolormesh(lons, lats, diff_mask['cam+sam']['500mb'][season], cmap=cmap, norm=norm, transform=trans)
ax.pcolormesh(lons, lats, map_diff, cmap=cmap, norm=norm, transform=trans)
ax.coastlines(color='k')
#ax.add_feature(cfeature.LAND, fc='k', zorder=10)
ax.set_extent(map_bnds, crs=trans)
gl=ax.gridlines(crs=trans, lw=.5, colors='black', alpha=1.0, linestyle='--', zorder=10, draw_labels=True)
gl.bottom_labels=True; gl.left_labels=True; gl.top_labels=False; gl.right_labels=False
gl.xformatter=LONGITUDE_FORMATTER; gl.yformatter=LATITUDE_FORMATTER
gl.xlabel_style=grid_kw; gl.ylabel_style=grid_kw
lon_bnds=np.array([150,325,325,150])
lat_bnds=np.array([-40,-40,-10,-10])
ring=LinearRing(list(zip(lon_bnds, lat_bnds)))
ax.add_geometries([ring], crs=trans, fc='none', ec='k', lw=2, linestyle='-', zorder=100)

# (b) longitude-pressure section of the latitude-band mean difference
ax=plt.subplot2grid((2,1), (1, 0), colspan=1, rowspan=1)
ax.text(321,51,'b.', size=20, **text_kw, zorder=100, bbox=dict(facecolor='white', edgecolor='black'))
#ax.contourf(prof_diff.lon, levs, diff_mask['cam+sam']['zonal_profile'][season], cmap=cmap, levels=levels, extend='both')
ax.contourf(prof_diff.lon, levs, prof_diff, cmap=cmap, levels=levels, extend='both')
ax.set_xlabel('Longitude [°E]', weight='normal', size=12)
ax.set(xlim=[xmin, xmax], ylim=[1000,0])
ax.set_ylabel('[hPa]', weight='normal', size=12)
ax.set_yticks([800,600,400,200])
ax.tick_params(**tkw)



#"""
# add colorbar
cf=mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
cax=fig.add_axes([1.005, 0.125, 0.02, 0.775])
cbar=fig.colorbar(cf, orientation='vertical', extend='both', cax=cax)
cbar.set_label('[Pa/s]', y=1.09, labelpad=-40, rotation=0, size=10, fontweight='normal', ha='center')
cbar.ax.tick_params(labelsize=10)
for tick in cbar.ax.yaxis.get_major_ticks():
    tick.label2.set_fontweight('normal')
#"""

#plt.savefig(f'{opath}/omega.ca-sa.{season}.masked.pdf', transparent=True, bbox_inches='tight')