In [10]:
import os
import gc
import sys
import pdb
import copy
import glob
import psutil
import imageio
import calendar
import importlib
import numpy as np
import xarray as xr
import cmasher as cmr
import matplotlib as mpl
from datetime import datetime
from datetime import timedelta
import matplotlib.pyplot as plt
from matplotlib import gridspec
import matplotlib.ticker as ticker
import matplotlib.patches as patches
from memory_profiler import memory_usage
from matplotlib.patches import Rectangle
from cftime import DatetimeNoLeap as date
from matplotlib.colors import TwoSlopeNorm
from matplotlib.dates import DateFormatter
from matplotlib.ticker import FuncFormatter
from matplotlib.colors import LinearSegmentedColormap

plt.rcParams.update({
    'font.size'       : 14,   # Base font size
    'axes.labelsize'  : 14,   # Axis labels
    'xtick.labelsize' : 12,   # X-axis tick labels
    'ytick.labelsize' : 12,   # Y-axis tick labels
    'legend.fontsize' : 12,   # Legend font size
    'figure.titlesize': 16    # Figure title size
})
# Turn off interactive mode: figures created in the frame loops are saved and
# closed explicitly instead of being displayed. (For a fully headless run the
# backend can be switched with mpl.use('Agg') before any figure is created.)
plt.ioff()

# Locations of the project helper modules. The NERSC paths are the defaults;
# set these environment variables to run from a different checkout.
TEM_UTIL_DIR = os.environ.get(
    'TEM_UTIL_DIR',
    '/global/homes/j/jhollo/repos/climate_analysis/CLDERA/TEM/limvar_analysis_NERSC')
WAVEPAPER_UTIL_DIR = os.environ.get(
    'WAVEPAPER_UTIL_DIR',
    '/global/homes/j/jhollo/repos/climate_analysis/CLDERA/wavePaperFigs/util')


def _prepend_to_path(path):
    """Put `path` at sys.path[1], dropping earlier copies so re-running the cell does not duplicate it."""
    sys.path[:] = [p for p in sys.path if p != path]
    sys.path.insert(1, path)


_prepend_to_path(TEM_UTIL_DIR)
import plotting_utils as putil
import compute_ensemble_stats as ces
# reload so that edits to the helper modules are picked up when this cell is re-run
importlib.reload(putil)
importlib.reload(ces)
fix_dtint   = putil.adjust_10daily_integrated_tendency
shift_dtint = putil.shift_integrated_tendency
sig         = putil.filter_significance
cmn         = putil.get_cmap_norm

_prepend_to_path(WAVEPAPER_UTIL_DIR)
import nclcmaps as ncm

In [11]:
importlib.reload(putil)
importlib.reload(ces)

# Variables to load; each is returned by putil.get_variable with 'impact',
# 'pval' and 'cfmean' products. Order here is the merge order of the datasets.
VARNAMES = ['T', 'U', 'utendepfd', 'utendresvel', 'utendgw', 'utenddiff',
            'epfy', 'epfz', 'psitem', 'psitem_gradlat', 'psitem_gradp']
kwargs = {'freq':'monthly', 'return_intersection':False}


def _index_by_month(ds):
    """Replace the datetime `time` axis of `ds` with an integer `month` index 1..n.

    The calendar information is kept as `year` and `calmonth` coordinates
    along the new `month` dimension.
    """
    calmonths, years = ds.time.dt.month.values, ds.time.dt.year.values
    months = np.arange(1, len(calmonths) + 1, 1)
    ds = ds.assign_coords(month=('time', months))
    ds = ds.assign_coords(year=('time', years))
    ds = ds.assign_coords(calmonth=('time', calmonths))
    ds = ds.drop_vars('time')
    ds = ds.rename(time='month')
    return ds.set_xindex('month')


# ----- load each variable once and split it into its three products
impact_vars, pval_vars, cf_vars = [], [], []
for name in VARNAMES:
    result = putil.get_variable(name, **kwargs)
    impact_vars.append(result['impact'])
    pval_vars.append(result['pval'])
    cf_vars.append(result['cfmean'])
# drop the last full result dict so entries we did not keep can be freed
del result
# the tropopause pressure is only used as the ensemble mean, alongside the impact fields
impact_vars.append(putil.get_variable('TROP_P', **kwargs)['ensmean'])

# ----- combine into datasets indexed by an integer month counter
impact = _index_by_month(xr.merge(impact_vars))
pval   = _index_by_month(xr.merge(pval_vars))
cf     = _index_by_month(xr.merge(cf_vars))

# (year, calendar month) pairs of the monthly record, one row per month
ym = np.vstack([cf.year.values, cf.calmonth.values]).T


In [109]:
# Pick up any edits made to the helper modules since they were imported.
for _module in (putil, ces):
    importlib.reload(_module)

def create_gradient_image(width=256, height=1):
    """Build a white RGBA image whose alpha channel is opaque, then fades out.

    Each row has ``int(width)`` fully opaque pixels followed by a linear
    alpha ramp from 1 to 0 over ``int(width/2)`` pixels, so the image is
    ``height`` x (int(width) + int(width/2)) x 4 with float values in [0, 1].
    """
    opaque = np.ones(int(width))
    fade = np.linspace(1, 0, int(width/2))
    alpha = np.tile(np.concatenate((opaque, fade)), (height, 1))
    white = np.ones_like(alpha)
    # stack along a new trailing axis -> (height, n_pixels, RGBA)
    return np.stack((white, white, white, alpha), axis=-1)

def power_law_shift(latw):
    """Empirical cubic fit for the initial scrolling-title offset as a function of `latw`.

    The coefficients were fitted by hand; the polynomial is evaluated in
    ``x = latw / 1.5``. Works on scalars and numpy arrays alike.
    """
    x = latw/1.5
    c3, c2, c1, c0 = 1.83102, -5.64225, 6.31629, -2.5344
    return c3*x**3 + c2*x**2 + c1*x + c0

def make_figure(lat_range):
    """Create a single-panel latitude-pressure figure sized to the latitude span.

    Parameters
    ----------
    lat_range : sequence of 2 floats
        [lat_min, lat_max] of the panel in degrees.

    Returns
    -------
    fig, ax : the new figure and its axes (pressure axis limited to 0.3-1000 hPa)
    total_title_shift : float
        Total horizontal travel of the scrolling calendar title, in axes units.
    initial_title_shift : float
        Starting x position of the scrolling title, in axes units.
    latw : float
        Width factor: latitude span in units of 90 degrees, clamped to [0.7, 1.5].
    """
    # use the argument (previously the body silently read a global `latrange`)
    latw = min((lat_range[1] - lat_range[0])/90, 1.5)
    latw = max(latw, 0.7)
    fig = plt.figure()
    ax = fig.add_subplot(111)
    fig.set_size_inches(5*latw, 5)
    ax.set_ylim([0.3, 1000])
    putil.format_paxis(ax)
    putil.format_lataxis(ax)
    ax.set_ylabel('pressure [hPa]')
    putil.format_ticks(ax)
    # title travel scales inversely with figure width; 1.028 corresponds to latw == 1.5
    total_title_shift   = 1.028/(latw/1.5)
    # fixed start; power_law_shift(latw) is an alternative empirical fit
    initial_title_shift = 0.48
    return fig, ax, total_title_shift, initial_title_shift, latw
    

def _interp_in_time(ds, latrange, month_interp, method):
    """Subset `ds` to `latrange` and interpolate it onto the `month_interp` axis.

    `method` is passed to xarray ``interp``; ``assume_sorted=True`` relies on
    the `month` index being increasing (the data cell builds it as 1..n).
    """
    subset = ds.sel(lat=slice(latrange[0], latrange[1]))
    return subset.interp(month=month_interp, method=method, assume_sorted=True)


def make_plot(var, N, method, levels, cflevels, latrange=[-90,90], rvvec=True, epvec=True, dslat=1, dsp=1, epvscale=1e13,
              logepvec=False, rvvscale=10, logrvvec=True, plotT=False, labelT=False, plotQBO=False, plotCFc=True,
              plotSig=True, cnorm='twoslope', cblab=None):
    """Render one animation of the `var` impact as a sequence of PNG frames.

    Reads the module-level `impact`, `pval` and `cf` datasets, interpolates
    them in time onto `N` frames per year of record and writes each frame to
    ``figs/impact_anim/frames/<var>_impact_<total frames>_<method><flags>/NNNN.png``.

    Parameters
    ----------
    var : str
        Field to shade; one of 'U', 'psitem', 'utendepfd', 'utendresvel',
        'utendgw', 'utenddiff' (keys of the colormap/scaling tables below).
    N : int
        Frames per year; multiplied by ``int(len(impact.month)/12)``.
    method : str
        Temporal interpolation method passed to xarray ``interp``.
    levels, cflevels
        Filled-contour levels for the impact and line-contour levels for the
        counterfactual field, both in the scaled units of `var`.
    latrange : sequence of 2 floats
        Latitude bounds of the panel.
    rvvec, epvec : bool
        Show residual-velocity / EP-flux vectors. Hidden vectors are still drawn
        with alpha=0 (presumably to keep the layout identical across animations).
    dslat, dsp : int
        Forwarded to the putil vector helpers (presumably lat / pressure
        subsampling — confirm in plotting_utils).
    epvscale, rvvscale : float
        ``quiver`` scale for EP-flux and residual-velocity vectors.
    logepvec, logrvvec : bool
        Forwarded as ``log_vectors`` to the putil vector helpers.
    plotT, labelT : bool
        Overlay (and optionally label) contours of the impact dataset's `T`
        field at 230, 235, 240, 245.
    plotQBO : bool
        Draw the QBO averaging box, a colorbar marker (only when var == 'U') and
        a growing QBO time-series panel.
    plotCFc, plotSig : bool
        Overlay counterfactual contours / the p-value significance contour and
        hatching of insignificant regions.
    cnorm : str
        Normalisation name passed to ``putil.get_cmap_norm``.
    cblab : str, optional
        Colorbar label; defaults to the zonal-wind impact label.

    Side effects: sets ``hatch.linewidth`` and ``hatch.color`` in the global
    matplotlib rcParams and creates the output directory if needed.
    """
    if cblab is None:
        cblab = '$\\overline{{u}}$ impact [m/s]'

    N *= int(len(impact.month)/12) # input N is number of frames per year

    # ----- interpolate impact, p-values and counterfactual onto a common fine time axis
    month_interp  = np.linspace(impact.month.values[0], impact.month.values[-1], N)
    impact_interp = _interp_in_time(impact, latrange, month_interp, method)
    pval_interp   = _interp_in_time(pval, latrange, month_interp, method)
    cf_interp     = _interp_in_time(cf, latrange, month_interp, method)
    # Year for the title: linear interpolation stays between neighbouring monthly
    # years, so int() only advances at the January node (a cubic fit could ring).
    years_interp = np.interp(month_interp, impact.month.values, impact.year.values)

    # --- plotting settings
    ulw, ucolor           = 1, 'k'          # counterfactual contour linewidth, color
    vw, vhw, vhl          = 0.008, 2.4, 4   # vector arrow width, head width, and head length to quiver()
    epvcolor              = 'k'             # EP flux vector arrow color
    rvvcolor              = 'royalblue'     # residual velocity vector arrow color
    vecolor               = 'w'             # vector outline color
    vlw                   = 0.75            # vector outline linewidth
    pbuff                 = 20              # tropopause buffer when removing tropospheric vectors
    tlw, tcolor           = 6, 'grey'       # tropopause linewidth, color
    isiglw, isigcolor     = 2.5, 'w'        # linewidth and color for significance contours
    ihatch                = '////'          # insignificance hatching type
    ihatchlw, ihatchcolor = 1.3, 'w'        # linewidth and color for insignificance hatching
    pcrit = 0.05                            # threshold in p-value to determine significance
    cmap = {'U':'RdBu_r', 'psitem':'BrBG_r', 'utendepfd':'BrBG_r', 'utendgw':'BrBG_r', 'utendresvel':'BrBG_r', 'utenddiff':'BrBG_r'}[var]
    # tendencies are multiplied by 30*86400 s (presumably per-second -> per-30-day-month)
    scaling = {'U':1, 'psitem':1, 'utendepfd':2592000, 'utendresvel':2592000, 'utendgw':2592000, 'utenddiff':2592000}
    cnorm = putil.get_cmap_norm(levels, cnorm)
    cfargs  = {'levels':levels, 'cmap':cmap, 'extend':'both', 'norm':cnorm}
    cfcargs = {'levels':cflevels, 'colors':ucolor, 'linewidths':ulw}
    interp_vectors        = True
    remove_trop_epvectors = False
    remove_trop_rvvectors = True
    tmcolor, tmlw = 'k', 1.1
    Tlevels = [230, 235, 240, 245]
    ctargs  = {'colors':tmcolor, 'linewidths':tmlw}
    mpl.rcParams['hatch.linewidth'] = ihatchlw
    mpl.rcParams['hatch.color']     = ihatchcolor

    # scrolling calendar title: Mar/Jun/Sep/Dec abbreviations of the monthly record
    monthstrs = np.array([calendar.month_abbr[m] for m in impact.calmonth.values])
    mask = [m in [6, 9, 12, 3] for m in impact.calmonth.values]
    sliding_calendar = '   '.join(monthstrs[mask])

    # --- output location (one directory per flag combination)
    epvstr = ['', '_EPVEC'][epvec]
    rvvstr = ['', '_RVVEC'][rvvec]
    Tstr   = ['', '_T'][plotT]
    Tlstr  = ['', 'labeled'][labelT]
    qbostr = ['', '_QBO'][plotQBO]
    cfcstr = ['', '_CFc'][plotCFc]
    sigstr = ['', '_Pval'][plotSig]
    savedir = f'figs/impact_anim/frames/{var}_impact_{N}_{method}{epvstr}{rvvstr}{Tstr}{Tlstr}{qbostr}{cfcstr}{sigstr}'
    os.makedirs(savedir, exist_ok=True)

    # --- make axes
    fig, ax, total_title_shift, initial_title_shift, latw = make_figure(latrange)

    # ---- scale EP flux vectors
    trop = impact_interp['TROP_P']
    epfy_, epfz_, epfyi_, epfzi_ = impact_interp['epfy'], impact_interp['epfz'], pval_interp['epfy'], pval_interp['epfz']
    if interp_vectors:
        # interpolate the vector field to a uniform grid in lat-log(p)
        epfy_, epfz_, epfyi_, epfzi_ = putil.regrid_vectors_latp(epfy_, epfz_, usig=epfyi_, vsig=epfzi_)
    if remove_trop_epvectors:
        # remove vectors in the troposphere so that they don't influence the length scaling
        epfy_ = putil.remove_troposphere(epfy_, trop, buffer=pbuff)
        epfz_ = putil.remove_troposphere(epfz_, trop, buffer=pbuff)
    # finally do EP flux vector scaling
    # This MUST be done last so that the vector scaling is not corrupted by later
    # modifications to the figure or axes size or data limits
    Fx, Fy = putil.scale_EP_flux_vectors(fig, ax, epfy_, epfz_, dslat=dslat, dsp=dsp, log_vectors=logepvec,
                                         dsplog=False, interp_lat=False, interp_plev=False)

    # ---- scale residual velocity vectors
    glat_, gp_, glati_, gpi_ = impact_interp['psitem_gradlat'], impact_interp['psitem_gradp'], pval_interp['psitem_gradlat'], pval_interp['psitem_gradp']
    if interp_vectors:
        glat_, gp_, glati_, gpi_ = putil.regrid_vectors_latp(glat_, gp_, usig=glati_, vsig=gpi_)
    if remove_trop_rvvectors:
        # remove vectors in the troposphere so that they don't influence the length scaling
        glat_ = putil.remove_troposphere(glat_, trop, buffer=pbuff)
        gp_   = putil.remove_troposphere(gp_, trop, buffer=pbuff)
    # finally get gradient-normal field from streamfunction
    gnx, gny = putil.streamfunction_gradient_normal(fig, ax, dfdlat=glat_, dfdp=gp_, dslat=dslat, dsp=dsp,
                                                    dsplog=False, interp_lat=False, interp_plev=False,
                                                    log_vectors=logrvvec)

    # left/right fades over the scrolling title (identical for every frame)
    gh, gw = 62, 250
    left_gradient_img  = create_gradient_image(width=gw, height=gh)
    right_gradient_img = left_gradient_img[:, ::-1]
    # guard against N == 1 (single frame) in the title-shift fraction
    frame_denom = max(N - 1, 1)

    for i, month in enumerate(impact_interp.month.values):

        # configure plot shape, axes (frame 0 reuses the figure used for vector scaling)
        if i > 0:
            fig, ax, total_title_shift, initial_title_shift, latw = make_figure(latrange)

        # make sliding title
        titlepos = [-0.25, 0.9, 1.5, 0.1]
        tax = fig.add_axes(titlepos, zorder=99)
        tax.set_axis_off()
        tax2 = fig.add_axes(titlepos, zorder=99)
        tax2.set_axis_off()

        # make scrolling calendar title
        shift_fraction = i / frame_denom * total_title_shift
        tax.text(initial_title_shift-shift_fraction, 1.07, sliding_calendar,
                 ha='left', va='center', fontsize=13, clip_on=False,
                 transform=ax.transAxes)
        ax.set_title(str(int(years_interp[i]))+'\n\n', fontsize=14)

        # add gradient overlays on the left and right sides of the title
        frac = min(latw, 1.2)
        tax.imshow(left_gradient_img, extent=[0, 0.4*frac, 0, 1], transform=tax.transAxes, alpha=1, zorder=99)
        tax.imshow(right_gradient_img, extent=[1-(0.4*frac), 1, 0, 1], transform=tax.transAxes, alpha=1, zorder=99)

        # --- plot variable
        cv  = impact_interp[var].sel(month=month) * scaling[var]
        cvc = ax.contourf(cv.lat, cv.plev, cv.T, **cfargs)
        # --- plot counterfactual variable
        if plotCFc:
            cfcv = cf_interp[var].sel(month=month) * scaling[var]
            ax.contour(cfcv.lat, cfcv.plev, cfcv.T, **cfcargs, zorder=99)
        # --- plot significance
        if plotSig:
            cvsig = pval_interp[var].sel(month=month)
            ax.contour(cvsig.lat, cvsig.plev, cvsig.T, colors=isigcolor, levels=[pcrit], linewidths=isiglw)
            pmax = float(cvsig.max())
            # contourf needs increasing levels; nothing to hatch if all p <= pcrit
            if pmax > pcrit:
                ax.contourf(cvsig.lat, cvsig.plev, cvsig.T, levels=[pcrit, pmax], hatches=[ihatch], alpha=0)

        # --- plot tropopause (divided by 100, presumably Pa -> hPa)
        ax.plot(trop.lat, trop.sel(month=month)/100, color=tcolor, lw=tlw)

        # --- overlay temperature
        if plotT:
            ct = impact_interp['T'].sel(month=month)
            for k, temp in enumerate(Tlevels):
                ctc = ax.contour(ct.lat, ct.plev, ct.T, levels=[temp], **ctargs)
                if labelT:
                    # place labels at evenly spaced latitudes, at the level nearest `temp` above 100 hPa
                    cx = (180/(len(Tlevels)+1)*(k+1))-90
                    cy = abs(ct.sel(plev=slice(0,100)).sel(lat=cx, method='nearest')-temp).idxmin().values
                    ax.clabel(ctc, [temp], inline=True, fmt='%d', fontsize=8, manual=[[cx, cy]])

        # --- overlay EP flux vector field
        LAT, PLEV = np.meshgrid(Fx.lat, Fx.plev)
        ax.quiver(LAT, PLEV, Fx.sel(month=month).T, Fy.sel(month=month).T, width=vw, headwidth=vhw, headlength=vhl,
                  headaxislength=vhl*0.9, alpha=1 if epvec else 0, scale=epvscale, scale_units='inches',
                  color=epvcolor, zorder=20, edgecolor=vecolor, linewidth=vlw)

        # --- overlay residual velocity vector field
        LAT, PLEV = np.meshgrid(gnx.lat, gnx.plev)
        ax.quiver(LAT, PLEV, gnx.sel(month=month).T, gny.sel(month=month).T, scale=rvvscale, width=vw, headwidth=vhw,
                  headlength=vhl, headaxislength=vhl*0.9, color=rvvcolor, zorder=20, edgecolor=vecolor,
                  linewidth=vlw, alpha=1 if rvvec else 0)

        # --- colorbar
        cbheight = 0.77
        cbpos = [0.93, (1-cbheight)/2-0.005, 0.04, cbheight]
        cax = fig.add_axes(cbpos)
        cb = fig.colorbar(cvc, cax=cax, orientation='vertical', location='right',
                          format=FuncFormatter(putil.cbarfmt), extendrect=True)
        cb.set_label(cblab)
        cb.set_ticks(cvc.levels)

        if plotQBO:
            # -------- draw QBO averaging box
            rect = patches.Rectangle((-5, 20), 10, 40, linewidth=1.5, edgecolor='m', facecolor='none')
            ax.add_patch(rect)
            # -------- QBO history up to AND including the current frame, always in U's units
            u_hist = impact_interp['U'].isel(month=slice(0, i+1)) * scaling['U']
            qbo = u_hist.sel(lat=slice(-5,5)).mean('lat')
            qbo = qbo.sel(plev=slice(20,50)).mean('plev')
            # -------- draw current QBO value on the colorbar (only meaningful when shading U)
            if var == 'U':
                cax.plot([0.5], [float(qbo.isel(month=-1))], '>m', ms=7)
            # -------- plot QBO time series on second axes
            qpos = [1.2, (1-cbheight)/2-0.005, 0.9, cbheight]
            qax  = fig.add_axes(qpos)
            putil.format_ticks(qax, y='right', x='bottom')
            qax.plot(impact_interp.month.isel(month=slice(0, i+1)), qbo, '-k', lw=2)
            qax.set_xlim([0, impact_interp.month.max().values])
            qax.set_ylim([-12, 12])
            qax.axhline(y=0, ls=':', color='k', lw=1)

        # -------- save, close
        fig.savefig('{}/{}.png'.format(savedir, str(i).zfill(4)), dpi=250, bbox_inches='tight', bbox_extra_artists=[ax, tax2])
        fig.clf()
        plt.close(fig)
        gc.collect()
        print('{}/{}'.format(i+1, N), end='\r')

    print('done')

In [117]:
# ----- animation settings
N       = 200       # frames per year of record (make_plot scales by the number of years)
method  = 'cubic'   # temporal interpolation method for the frames
plotSig = True      # significance contour/hatching in every animation
# Per-animation switches: entry k of each list configures animation k.
rvvec   = [False, False, False, False, True]
epvec   = [False, False, False, True, False]
plotT   = [False, True, False, False, False]
labelT  = [False, True, True, False, False]
plotCFc = [True, False, False, True, True]
plotQBO = [False, False, False, False, False]

# Shared panel settings (global-latitude view).
latrange = [-90, 90]
cnorm    = 'twoslope'
ulevels  = np.arange(-8, 8.1, 1)
cflevels = np.arange(-50, 71, 10)
# Tropical-view alternative:
# latrange, cnorm, ulevels = [-25, 25], 'uneven', [-45, -30, -15, -5, -3, -2, -1, 0, 1, 2, 3, 5, 15, 30, 45]

for i in range(len(plotT)):
    # only (re)render animations after the first that draw neither temperature
    # contours nor residual-velocity vectors
    if i == 0 or plotT[i] or labelT[i] or rvvec[i]:
        continue
    make_plot('U', N, method, latrange=latrange, levels=ulevels, cflevels=cflevels, cnorm=cnorm,
              epvec=epvec[i], rvvec=rvvec[i], epvscale=1.5e13, rvvscale=40, dslat=8, dsp=6,
              logepvec=False, logrvvec=True, plotT=plotT[i], labelT=labelT[i], plotQBO=plotQBO[i],
              plotCFc=plotCFc[i], plotSig=plotSig)

done600


In [118]:
# ----- stitch the saved PNG frames of each configuration into an mp4
var = 'U'
# make_plot multiplies its per-year frame count by the number of years in the
# record and names the frame directories with that total.
n_frames = N * int(len(impact.month)/12)
# frame rate scales with N but is capped at 30 fps (previously computed but never used)
fps = min(int(N*8/10), 30)

for i in range(len(plotT)):
    suffix = (['', '_EPVEC'][epvec[i]] + ['', '_RVVEC'][rvvec[i]] + ['', '_T'][plotT[i]]
              + ['', 'labeled'][labelT[i]] + ['', '_QBO'][plotQBO[i]] + ['', '_CFc'][plotCFc[i]]
              + ['', '_Pval'][plotSig])
    name = f'{var}_impact_{n_frames}_{method}{suffix}'
    frames = sorted(glob.glob(f'figs/impact_anim/frames/{name}/*.png'))
    if not frames:
        # this configuration was not rendered; don't write an empty/broken video
        print(f'no frames for {name}; skipping')
        continue

    with imageio.get_writer(f'figs/impact_anim/{name}.mp4', mode='I', fps=fps, codec='libx264') as writer:
        for j, file in enumerate(frames, start=1):
            print(j, end='\r')
            image = imageio.v2.imread(file)
            image = image[:,int(image.shape[1]*0.18):int(image.shape[1]*0.865), :] # crop title
            writer.append_data(image)



4

[rawvideo @ 0x5501a40] Stream #0: not enough frames to estimate rate; consider increasing probesize


600



4

[rawvideo @ 0x5501a40] Stream #0: not enough frames to estimate rate; consider increasing probesize


600



4

[rawvideo @ 0x5501a40] Stream #0: not enough frames to estimate rate; consider increasing probesize


600