In [54]:
import os
import gc
import sys
import pdb
import copy
import glob
import psutil
import imageio
import calendar
import importlib
import numpy as np
import xarray as xr
import cmasher as cmr
import matplotlib as mpl
from datetime import datetime
from datetime import timedelta
import matplotlib.pyplot as plt
from matplotlib import gridspec
import matplotlib.ticker as ticker
import matplotlib.patches as patches
from memory_profiler import memory_usage
from matplotlib.patches import Rectangle
from cftime import DatetimeNoLeap as date
from matplotlib.colors import TwoSlopeNorm
from matplotlib.dates import DateFormatter
from matplotlib.ticker import FuncFormatter
from matplotlib.colors import LinearSegmentedColormap

plt.rcParams.update({
    'font.size'       : 14,   # Base font size
    'axes.labelsize'  : 14,   # Axis labels
    'xtick.labelsize' : 12,   # X-axis tick labels
    'ytick.labelsize' : 12,   # Y-axis tick labels
    'legend.fontsize' : 12,   # Legend font size
    'figure.titlesize': 16    # Figure title size
})
# Frames are saved to disk and closed rather than shown inline, so keep interactive mode off.
plt.ioff()
# Alternative non-interactive backends (disabled):
#mpl.use('Agg')
#%matplotlib agg

# Project helper modules live outside this notebook; these paths are absolute on NERSC.
REPO_ROOT = '/global/homes/j/jhollo/repos/climate_analysis/CLDERA'

TEM_UTIL_DIR = os.path.join(REPO_ROOT, 'TEM', 'limvar_analysis_NERSC')
if TEM_UTIL_DIR not in sys.path:   # guard so re-running this cell does not keep growing sys.path
    sys.path.insert(1, TEM_UTIL_DIR)
import plotting_utils as putil
import compute_ensemble_stats as ces
# Reload so edits to the helper modules are picked up without restarting the kernel.
importlib.reload(putil)
importlib.reload(ces)
fix_dtint   = putil.adjust_10daily_integrated_tendency
shift_dtint = putil.shift_integrated_tendency
sig         = putil.filter_significance
cmn         = putil.get_cmap_norm

WAVE_UTIL_DIR = os.path.join(REPO_ROOT, 'wavePaperFigs', 'util')
if WAVE_UTIL_DIR not in sys.path:
    sys.path.insert(1, WAVE_UTIL_DIR)
import nclcmaps as ncm

In [57]:
importlib.reload(putil)
importlib.reload(ces)

# ----- get data
# Each field is the 'cfmean' entry returned by putil.get_variable
# (presumably the counterfactual ensemble mean -- confirm in plotting_utils).
kwargs = {'freq': 'monthly', 'return_intersection': False, 'return_members': False}
T           = putil.get_variable('T', **kwargs)['cfmean']
u           = putil.get_variable('U', **kwargs)['cfmean']
utendepfd   = putil.get_variable('utendepfd', **kwargs)['cfmean']
utendresvel = putil.get_variable('utendresvel', **kwargs)['cfmean']
utendgw     = putil.get_variable('utendgw', **kwargs)['cfmean']
utenddiff   = putil.get_variable('utenddiff', **kwargs)['cfmean']
epfy        = putil.get_variable('epfy', **kwargs)['cfmean']
epfz        = putil.get_variable('epfz', **kwargs)['cfmean']
psitem      = putil.get_variable('psitem', **kwargs)['cfmean']
# streamfunction gradients; make_plot feeds 'psitem_gradlat' as dfdlat and 'psitem_gradp' as dfdp
psitem_glat = putil.get_variable('psitem_gradlat', **kwargs)['cfmean']
psitem_gp   = putil.get_variable('psitem_gradp', **kwargs)['cfmean']
trop        = putil.get_variable('TROP_P', **kwargs)['cfmean']

# ----- combine data; replace the time axis by a sequential month index 1..n
allvars = [T, u, utendepfd, utendresvel, utendgw, utenddiff, epfy, epfz, psitem, psitem_glat, psitem_gp, trop]
cf = xr.merge(allvars)
calmonths = cf.time.dt.month.values   # calendar month of each time step
years     = cf.time.dt.year.values
months    = np.arange(1, len(calmonths) + 1, 1)
cf = (cf.assign_coords(month=('time', months), year=('time', years), calmonth=('time', calmonths))
        .drop_vars('time')
        .rename(time='month')
        .set_xindex('month'))
# (year, calendar month) pairs, one row per month index
ym = np.column_stack((cf.year.values, cf.calmonth.values))


In [4]:
# Re-import the helper modules so edits made since the last run take effect.
for _helper_module in (putil, ces):
    importlib.reload(_helper_module)

def create_gradient_image(width=256, height=1):
    """Build a white RGBA image whose alpha channel is opaque then fades to transparent.

    The alpha row is `width` ones followed by a linear ramp from 1 to 0 of length
    `int(width/2)`, so the returned image is roughly 1.5*width pixels wide (not `width`).

    Returns
    -------
    numpy.ndarray of shape (height, int(width) + int(width/2), 4), float64.
    """
    plateau = np.ones(int(width))
    ramp = np.linspace(1, 0, int(width/2))
    alpha = np.tile(np.concatenate((plateau, ramp)), (height, 1))
    white = np.ones_like(alpha)
    return np.stack((white, white, white, alpha), axis=-1)

def power_law_shift(latw):
    """Evaluate an empirical cubic in x = latw/1.5.

    The coefficients were obtained by a manual cubic fit.
    """
    x = latw/1.5
    c3, c2, c1, c0 = 1.83102, -5.64225, 6.31629, -2.5344
    return c3*x**3 + c2*x**2 + c1*x + c0

def make_figure(lat_range, qbo):
    """Create a single-panel latitude-pressure figure and return (fig, ax).

    Parameters
    ----------
    lat_range : 2-element sequence
        [min_lat, max_lat]; the figure width scales with this span
        (90 degrees -> factor 1), clamped to [0.7, 1.5] times a 5-inch base.
    qbo : bool
        Currently unused; kept so existing calls remain valid.
    """
    latw = (lat_range[1] - lat_range[0])/90
    latw = max(min(latw, 1.5), 0.7)
    fig = plt.figure()
    fig.set_size_inches(5*latw, 5)
    ax = fig.add_subplot(111)
    ax.set_ylim([0.3, 1000])
    # axis styling is delegated to plotting_utils helpers
    putil.format_paxis(ax)
    putil.format_lataxis(ax)
    ax.set_ylabel('pressure [hPa]')
    putil.format_ticks(ax)
    return fig, ax
    

def make_plot(var, N, method, levels, latrange=(-90, 90), rvvec=True, epvec=True, dslat=1, dsp=1, epvscale=1e13,
              logepvec=False, rvvscale=10, logrvvec=True, plotT=False, labelT=False, plotQBO=False, cnorm='twoslope'):
    """Render one PNG per interpolated time step of a counterfactual latitude-pressure field.

    Reads the module-level dataset ``cf``, interpolates it in time, and writes frames to
    ``figs/cf_anim/frames/<var>_cf_<8*N>_<method><suffixes>/NNNN.png``.

    Parameters
    ----------
    var : str
        Field to shade: 'U', 'psitem', 'utendepfd', 'utendgw', 'utendresvel' or 'utenddiff'.
    N : int
        Frames per year; 8*N frames are rendered in total (NOTE(review): assumes an 8-year record).
    method : str
        Interpolation method passed to ``Dataset.interp`` for the plotted fields.
    levels : sequence of float
        Contour levels for the shaded field.
    latrange : 2-element sequence
        Latitude bounds [min, max] used to subset the data and size the figure.
    rvvec, epvec : bool
        Show (True) or hide (alpha=0) the residual-velocity / EP-flux arrows.
    dslat, dsp : int
        Downsampling factors passed to the putil vector helpers.
    epvscale, rvvscale : float
        ``quiver`` scale for the EP-flux / residual-velocity arrows.
    logepvec, logrvvec : bool
        Passed as ``log_vectors`` to the respective putil vector helpers.
    plotT, labelT : bool
        Overlay (and label) temperature contours at 230, 235, 240, 245.
    plotQBO : bool
        Draw the QBO averaging box, a colorbar marker (var == 'U' only) and a QBO time-series panel.
    cnorm : str
        Normalisation mode passed to ``putil.get_cmap_norm``.
    """
    # ----- interpolate in time
    N *= 8  # input N is number of frames per year
    cf_month_interp = np.linspace(cf.month.values[0], cf.month.values[-1], N)
    cf_interp = cf.sel(lat=slice(latrange[0], latrange[1]))
    # linearly-interpolated copy, used only for the year shown in each frame title
    cf_interp_lin = cf_interp.interp(month=cf_month_interp, method='linear', assume_sorted=True)
    cf_interp = cf_interp.interp(month=cf_month_interp, method=method, assume_sorted=True)

    # --- plotting settings
    vw, vhw, vhl = 0.008, 2.4, 4   # vector arrow width, head width, and head length to quiver()
    epvcolor     = 'k'             # EP flux vector arrow color
    rvvcolor     = 'royalblue'     # residual velocity vector arrow color
    vecolor      = 'w'             # vector outline color
    vlw          = 0.75            # vector outline linewidth
    pbuff        = 20              # tropopause buffer used when removing tropospheric vectors
    tlw, tcolor  = 6, 'grey'       # tropopause linewidth, color
    cmap = {'U':'RdYlBu_r', 'psitem':'RdYlBu_r', 'utendepfd':'BrBG_r', 'utendgw':'BrBG_r', 'utendresvel':'BrBG_r', 'utenddiff':'BrBG_r'}[var]
    scaling = {'U':1, 'psitem':1, 'utendepfd':2592000, 'utendresvel':2592000, 'utendgw':2592000, 'utenddiff':2592000}
    cnorm = putil.get_cmap_norm(levels, cnorm)
    cfargs = {'levels': levels, 'cmap': cmap, 'extend': 'both', 'norm': cnorm}
    interp_vectors = True
    remove_trop_epvectors = False
    remove_trop_rvvectors = True
    Tlevels = [230, 235, 240, 245]
    ctargs  = {'colors': 'k', 'linewidths': 1.1}

    # colorbar / QBO-panel geometry (figure coordinates)
    cbheight = 0.77
    cbpos = [0.93, (1-cbheight)/2-0.005, 0.04, cbheight]
    qpos  = [1.2, (1-cbheight)/2-0.005, 0.9, cbheight]
    # NOTE(review): only the U label is known; other fields fall back to their variable name
    cblab = {'U': 'counterfactual $\\overline{{u}}$ [m/s]'}.get(var, f'counterfactual {var}')
    # QBO index region: the same bounds are used for the drawn box and for the average
    qbo_lat, qbo_plev = (-5, 5), (20, 50)

    # output directory (name encodes the options; the movie cell globs on it)
    epvstr = ['', '_EPVEC'][epvec]
    rvvstr = ['', '_RVVEC'][rvvec]
    Tstr   = ['', '_T'][plotT]
    Tlstr  = ['', 'labeled'][labelT]
    qbostr = ['', '_QBO'][plotQBO]
    savedir = 'figs/cf_anim/frames/{}_cf_{}_{}{}{}{}{}{}'.format(var, N, method, epvstr, rvvstr, Tstr, Tlstr, qbostr)
    os.makedirs(savedir, exist_ok=True)

    # --- make axes (this first figure is also the reference for vector scaling)
    fig, ax = make_figure(latrange, plotQBO)

    # ---- scale EP flux vectors
    trop = cf_interp['TROP_P']
    epfy_, epfz_ = cf_interp['epfy'], cf_interp['epfz']
    if interp_vectors:
        # interpolate the vector field to a uniform grid in lat-log(p)
        epfy_, epfz_ = putil.regrid_vectors_latp(epfy_, epfz_)
    if remove_trop_epvectors:
        # remove vectors in the troposphere so that they don't influence the length scaling
        epfy_ = putil.remove_troposphere(epfy_, trop, buffer=pbuff)
        epfz_ = putil.remove_troposphere(epfz_, trop, buffer=pbuff)
    # finally do EP flux vector scaling
    # This MUST be done last so that the vector scaling is not corrupted by later
    # modifications to the figure or axes size or data limits
    Fx, Fy = putil.scale_EP_flux_vectors(fig, ax, epfy_, epfz_, dslat=dslat, dsp=dsp, log_vectors=logepvec,
                                         dsplog=False, interp_lat=False, interp_plev=False)

    # ---- scale residual velocity vectors
    glat_, gp_ = cf_interp['psitem_gradlat'], cf_interp['psitem_gradp']
    if interp_vectors:
        glat_, gp_ = putil.regrid_vectors_latp(glat_, gp_)
    if remove_trop_rvvectors:
        # remove vectors in the troposphere so that they don't influence the length scaling
        glat_ = putil.remove_troposphere(glat_, trop, buffer=pbuff)
        gp_   = putil.remove_troposphere(gp_, trop, buffer=pbuff)
    # finally get gradient-normal field from streamfunction
    gnx, gny = putil.streamfunction_gradient_normal(fig, ax, dfdlat=glat_, dfdp=gp_, dslat=dslat, dsp=dsp,
                                                    dsplog=False, interp_lat=False, interp_plev=False,
                                                    log_vectors=logrvvec)

    # vector grids do not change between frames
    ep_lat, ep_plev = np.meshgrid(Fx.lat, Fx.plev)
    rv_lat, rv_plev = np.meshgrid(gnx.lat, gnx.plev)

    for i, month in enumerate(cf_interp.month.values):

        # frame 0 reuses the figure created above; later frames get a fresh one
        if i > 0:
            fig, ax = make_figure(latrange, plotQBO)

        ax.set_title(str(int(cf_interp_lin.year.values[i])), fontsize=14)

        # --- plot variable
        cv  = cf_interp[var].sel(month=month) * scaling[var]
        cvc = ax.contourf(cv.lat, cv.plev, cv.T, **cfargs)

        # --- plot tropopause (divided by 100, presumably Pa -> hPa)
        ax.plot(trop.lat, trop.sel(month=month)/100, color=tcolor, lw=tlw)

        # --- overlay temperature
        if plotT:
            ct = cf_interp['T'].sel(month=month)
            for k, temp in enumerate(Tlevels):
                ctc = ax.contour(ct.lat, ct.plev, ct.T, levels=[temp], **ctargs)
                if labelT:
                    # spread label positions evenly across the plotted latitude range
                    cx = latrange[0] + (latrange[1] - latrange[0])/(len(Tlevels)+1)*(k+1)
                    cy = abs(ct.sel(plev=slice(0, 100)).sel(lat=cx, method='nearest')-temp).idxmin().values
                    ax.clabel(ctc, [temp], inline=True, fmt='%d', fontsize=8, manual=[[cx, cy]])

        # --- overlay EP flux vector field (always drawn so frames share layout; hidden via alpha)
        ax.quiver(ep_lat, ep_plev, Fx.sel(month=month).T, Fy.sel(month=month).T, width=vw, headwidth=vhw,
                  headlength=vhl, headaxislength=vhl*0.9, alpha=1 if epvec else 0,
                  scale=epvscale, scale_units='inches', color=epvcolor, zorder=20, edgecolor=vecolor, linewidth=vlw)

        # --- overlay residual velocity vector field
        ax.quiver(rv_lat, rv_plev, gnx.sel(month=month).T, gny.sel(month=month).T, scale=rvvscale, width=vw,
                  headwidth=vhw, headlength=vhl, headaxislength=vhl*0.9,
                  color=rvvcolor, zorder=20, edgecolor=vecolor, linewidth=vlw, alpha=1 if rvvec else 0)

        # --- colorbar
        cax = fig.add_axes(cbpos)
        cb = fig.colorbar(cvc, cax=cax, orientation='vertical', location='right',
                          format=FuncFormatter(putil.cbarfmt), extendrect=True)
        cb.set_label(cblab)
        cb.set_ticks(cvc.levels)

        if plotQBO:
            # -------- draw QBO averaging box
            rect = patches.Rectangle((qbo_lat[0], qbo_plev[0]), qbo_lat[1]-qbo_lat[0], qbo_plev[1]-qbo_plev[0],
                                     linewidth=1.5, edgecolor='m', facecolor='none')
            ax.add_patch(rect)
            # -------- QBO index for every frame up to and including this one.
            # U is used unscaled: scaling[var] belongs to the shaded field, not to U.
            u_hist = cf_interp['U'].isel(month=slice(0, i+1))
            qbo = (u_hist.sel(lat=slice(*qbo_lat)).mean('lat')
                         .sel(plev=slice(*qbo_plev)).mean('plev'))
            if var == 'U':
                # mark the current QBO value on the U colorbar
                cax.plot([0.5], [float(qbo.isel(month=-1))], '>m', ms=7)
            # -------- plot QBO time series on second axes
            qax = fig.add_axes(qpos)
            putil.format_ticks(qax, y='right', x='bottom')
            qax.plot(cf_interp.month.isel(month=slice(0, i+1)), qbo, '-k', lw=2)
            qax.set_xlim([0, cf_interp.month.max().values])
            qax.set_ylim([-12, 12])
            qax.axhline(y=0, ls=':', color='k', lw=1)

        # -------- save, close
        fig.savefig(os.path.join(savedir, f'{i:04d}.png'), dpi=250, bbox_inches='tight')
        plt.close(fig)
        gc.collect()
        print('{}/{}'.format(i+1, N), end='\r')

    print('done')

In [7]:
N = 100          # frames per year; make_plot renders 8*N frames in total
method = 'cubic' # temporal interpolation method passed to make_plot

# Figure configurations: entry i of each list describes configuration i.
rvvec   = [False, False, False, False, False, True]
epvec   = [False, False, False, False, True, False]
plotT   = [False, False, True, True, False, False]
labelT  = [False, False, False, True, False, False]
plotQBO = [True, False, False, False, False, False]
# Only configuration 0 (U with the QBO panel) is rendered here; the movie cell below
# likewise assembles only configuration 0.
config_indices = [0]

# Latitude window, colour normalisation and U contour levels.
# Global alternative: latrange, cnorm, ulevels = [-90, 90], 'twoslope', np.arange(-50, 71, 10)
latrange, cnorm, ulevels = [-25, 25], 'uneven', [-45, -30, -15, -5, -3, -2, -1, 0, 1, 2, 3, 5, 15, 30, 45]

for i in config_indices:
    make_plot('U', N, method, latrange=latrange, levels=ulevels, cnorm=cnorm, epvec=epvec[i], rvvec=rvvec[i],
              epvscale=1.5e15, rvvscale=40, dslat=8, dsp=6, logepvec=False, logrvvec=True,
              plotT=plotT[i], labelT=labelT[i], plotQBO=plotQBO[i])

done800


In [8]:
# ----- assemble the saved PNG frames into an MP4 per configuration
# make_plot multiplies N by 8 internally, so its frame directories are keyed on N*8.
n_frames = N * 8
for i in range(len(plotT)):
    if i > 0:
        continue  # only configuration 0 is rendered by the cell above
    var = 'U'
    epvstr = ['', '_EPVEC'][epvec[i]]
    rvvstr = ['', '_RVVEC'][rvvec[i]]
    Tstr   = ['', '_T'][plotT[i]]
    Tlstr  = ['', 'labeled'][labelT[i]]
    qbostr = ['', '_QBO'][plotQBO[i]]
    # The movie name uses the same stem as the frame directory, so configurations
    # that differ only in the QBO panel no longer overwrite each other's movie.
    stem = f'{var}_cf_{n_frames}_{method}{epvstr}{rvvstr}{Tstr}{Tlstr}{qbostr}'
    frame_files = sorted(glob.glob(f'figs/cf_anim/frames/{stem}/*.png'))
    if not frame_files:
        print(f'no frames found for {stem}; skipping')
        continue
    # fps = n_frames/10, i.e. a ~10 s movie when all frames are present
    vargs = {'mode': 'I', 'fps': int(n_frames/10), 'codec': 'libx264'}
    with imageio.get_writer(f'figs/cf_anim/{stem}.mp4', **vargs) as writer:
        for j, file in enumerate(frame_files, start=1):
            print(j, end='\r')
            writer.append_data(imageio.v2.imread(file))



800

In [15]:
# Inspect the pressure levels of the merged counterfactual dataset.
cf['plev']