In [1]:
#!/usr/bin/env python
# coding: utf-8

# ### Notebook to generate CFADs from TC output
# 
# Assumes output is in a single netcdf file on pressure levels.
# 
# James Ruppert  
# jruppert@ou.edu  
# 4/23/22


from netCDF4 import Dataset
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from matplotlib import ticker, cm
import os
import cmocean
from mask_tc_track_4d import mask_tc_track_4d
from mask_tc_track_3d import mask_tc_track_3d
from grl_cfads_functions import cfads_var_settings, cfads_var_calc, mask_edges, plt_diff_compare
import sys


In [2]:
# #### Variable selection

# Variables to process (names understood by cfads_var_settings / cfads_var_calc)
ivar_all = ['thv','the']
# ivar_all = ['vmf']
nvar = len(ivar_all)

# #### Time selection
# Number of time steps sampled from each test's start step; also used as the
# hour tag ("Hour sample") in figure names.
# ntall=[1,3,6]
ntall = [6]

# #### Ensemble members
# Zero-based member indices for Maria, i.e. memb_01, memb_03-07, memb_09, memb_10.
ind_list = [0,2,3,4,5,6,8,9] # maria
# Number of ensemble members. Derived from ind_list so the two cannot drift apart
# (previously hard-coded as 8 alongside ind_list).
# nmem = 8 # number of ensemble members (1-10 have NCRF)
nmem = len(ind_list)


# #### Classification selection
# Options:
#   0 - non-raining, 1 - deep convective, 2 - stratiform,
#   3 - precipitating (any nonzero type code),
#   4 - storm (deep convective + stratiform + anvil), 5 - anvil,
#   -1 - classification off
# kclass=[0,1,2,3]
# kclass=[1,2]
# kclass=[2,3]
kclass = [1,2,4,5] # deep, strat, storm (dc+strat+anvil), anvil

# #### Storm selection
# storm_all=['haiyan','maria']
# storm_all=['haiyan']
storm_all = ['maria']
nstorm = len(storm_all)

# #### TC tracking
do_tc_track = True   # Localize to TC track? Use whole domain otherwise
ptrack = '850-600'   # tracking pressure layer (part of the track file name)
var_track = 'avor'   # tracked variable (part of the track file name)
rmax = 3             # radius handed to mask_tc_track_3d when localizing the classification

In [None]:
########## GO LOOPS GO ###############################################

########## VAR LOOP ###############################################

for ivar in range(nvar):
# for ivar in range(1):

  iplot = ivar_all[ivar]
  print("Variable: ",iplot)

  # Calculate anomaly as deviation from xy-mean
  do_prm_xy = 0
  # Calculate anomaly as time-increment
  do_prm_inc = 0
  # if (iplot == 'thv') or (iplot == 'qrad'):
  if (iplot == 'thv') or (iplot == 'the') or (iplot == 'lq') or (iplot == 'qv'):
      do_prm_xy = 1
  # Should be off for VMF
  if (iplot == 'vmf') or ('wpth' in iplot):
      do_prm_xy=0

  for istorm in range(nstorm):

    storm=storm_all[istorm]
    print("Storm: ",storm)

    # Tests to compare
    if storm == 'haiyan': # pick set of tests
      # tests = ['ctl','ncrf36h']
      # tests = ['ctl','STRATANVIL_OFF']
      # tests = ['ctl','STRATANVIL_ON']
      tests = ['ctl','STRAT_OFF']

    elif storm == 'maria': # pick set of tests
      # tests = ['ctl','ncrf36h']
      tests = ['ctl','ncrf48h']

    # #### Directories

    figdir = 
    figdiffdir = 
    main = 
    datdir2 = 

    enstag = str(nmem)
    memb0=1
    nums=np.arange(memb0,nmem+memb0,1); nums=nums.astype(str)
    nustr = np.char.zfill(nums, 2)
    # memb_all=np.char.add('memb_',nustr) # haiyan
    memb_all = ['memb_01','memb_03','memb_04','memb_05','memb_06','memb_07','memb_09','memb_10'] # maria


    ##### Get dimensions
    datdir = main+storm+'/'+memb_all[0]+'/'+tests[0]+'/'+datdir2
    varfil_main = Dataset(datdir+'T_HiRes.nc')
    nz = varfil_main.dimensions['level'].size
    nx1 = varfil_main.dimensions['lat'].size
    nx2 = varfil_main.dimensions['lon'].size
    pres = varfil_main.variables['pres'][:] # hPa
    varfil_main.close()

    # WRFOUT file list
    testdir = main+storm+'/'+memb_all[0]+'/'+tests[0]+'/'
    dirlist = os.listdir(testdir)
    subs="wrfout_d02"
    wrf_files = list(filter(lambda x: subs in x, dirlist))
    wrf_files.sort()
    wrfout = [testdir + s for s in wrf_files][0]
    varfil_main = Dataset(wrfout)
    lat = varfil_main.variables['XLAT'][:][0] # deg
    lon = varfil_main.variables['XLONG'][:][0] # deg
    varfil_main.close()

    # Get variable settings
    bins, fig_title, fig_tag, units_var, scale_mn, \
      units_mn, xrange_mn, xrange_mn2 = cfads_var_settings(iplot)
    nbin=np.shape(bins)[0]

    # Create axis of bin center-points
    bin_axis = (bins[np.arange(nbin-1)]+bins[np.arange(nbin-1)+1])/2

    if do_prm_xy == 1:
        fig_tag+='_xyp'
        fig_title+="$'$"# (xp)'
    if do_prm_inc == 1:
        fig_tag+='_tp'
        fig_title+=' (tp)'

    # Starting-read time step for model "restart" sensitivity tests
    t0_test=0


########## TIME LOOP ###############################################

    # #### Time selection
    i_nt=np.shape(ntall)[0]

    for knt in range(i_nt):
    # for knt in range(0,1):
    #for knt in range(3,i_nt+1):

      nt = ntall[knt]
      print(nt)
      hr_tag = str(np.char.zfill(str(nt), 2))
      print("Hour sample: ",hr_tag)
      
      ntest=2
      var_freq=np.ma.zeros((ntest,nbin-1,nz))
      var_freq_int=np.ma.zeros((ntest,nbin-1))

      # Create arrays to save ens members
      if do_prm_inc == 1:
        var_all = np.ma.zeros((ntest,nmem,nt+1,nz,nx1,nx2)) # time dim will be reduced to nt in the subtraction
      else:
        var_all = np.ma.zeros((ntest,nmem,nt,nz,nx1,nx2))
        # strat_all = np.ma.zeros((ntest,nmem,nt,nx1,nx2))
        strat_all = np.ma.zeros((ntest,nmem,nt,nx1,nx2))


########## TEST LOOP ###############################################

      for ktest in range(ntest):
      
        itest=tests[ktest]

        # This has been tested for corresponding time steps:
        #   t0=37,1 are the first divergent time steps in CTL,NCRF
        #   t0=25,1 are the first divergent time steps in NCRF,CRFON
        if itest == 'ctl':
          if tests[1] == 'ncrf36h':
            t0=36
          elif tests[1] == 'STRATANVIL_OFF':
            t0=36
          elif tests[1] == 'STRATANVIL_ON':
            t0=36
          elif tests[1] == 'STRAT_OFF':
            t0=36
          elif tests[1] == 'ncrf48h':
            t0=48
        elif itest == 'ncrf36h':
          t0=t0_test
        elif itest == 'STRATANVIL_OFF':
          t0=t0_test
        elif itest == 'STRATANVIL_ON':
          t0=t0_test
        elif itest == 'STRAT_OFF':
          t0=t0_test
        elif itest == 'ncrf48h':
          t0=t0_test
        elif itest == 'crfon':
          t0=t0_test

        if do_prm_inc == 0:
          t0+=1 # add one time step since NCRF(t=0) = CTL

        t1 = t0+nt
        if do_prm_inc == 1:
          t1+=1

        print('Running itest: ',itest)


########## ENS MEMBER LOOP ###############################################

        for imemb in range(nmem): 
      
          print('Running imemb: ',memb_all[imemb])
      
          datdir = main+storm+'/'+memb_all[imemb]+'/'+itest+'/'+datdir2
          print(datdir)
      
          # Localize to TC track
          if itest == 'ctl':
            track_file = datdir+'../../track_'+var_track+'_'+ptrack+'.nc'
          else:
            track_file = datdir+'../../track_'+var_track+'_'+ptrack+'.nc'

        # Two-dimensional variables

        # New stratiform scheme
          # varfil_main = Dataset(path) # memb x time x lat x lon; haiyan
          varfil_main = Dataset(path) # memb x time x lat x lon; maria
          strat = varfil_main.variables['type'][imemb,t0:t1,:,:]

        # Three-dimensional variables
          var = cfads_var_calc(iplot, datdir, pres, t0, t1, lat, lon)

        ### Process variable ##############################################

          # Calculate var' as anomaly from x-y-average, using large-scale (large-radius) var avg
          if do_prm_xy == 1:
            if do_tc_track:
              radius_ls=6 # Radius large-scale
              print(var.shape)
              var_ls = mask_tc_track_4d(track_file, radius_ls, var, lon, lat, t0, t1)
              var_ls_avg = np.ma.mean(var_ls,axis=(0,2,3))
            else:
              var_ls_avg = np.ma.mean(var,axis=(0,2,3))
            var -= var_ls_avg[np.newaxis,:,np.newaxis,np.newaxis]

          # Localize to TC track
          ## var = mask_tc_track(track_file, rmax, var, lon, lat, t0, t1)
          if do_tc_track:
            strat = mask_tc_track_3d(track_file, rmax, strat, lon, lat, t0, t1)
          else:
            strat = mask_edges(strat)

          # Save ens member
          var_all[ktest,imemb,:,:,:,:] = var
          strat_all[ktest,imemb,:,:,:] = np.squeeze(strat)

      # Calculate var' as time-increment: var[t] - var[t-1]
      if do_prm_inc == 1:
        var_all = var_all[:,:,1:,:,:,:] - var_all[:,:,:-1,:,:,:]

##### CLASSIFICATION LOOP ###############

      # 0-non-raining, 1-conv, 2-strat, 3-other/anvil, (-1 for off)
      nclass=np.shape(kclass)[0]
      for kstrat in range(nclass):
        
        istrat=kclass[kstrat]

        # Strat/Conv index subset
        if istrat == -1:
          fig_extra=''
        else:
          if istrat == 0:
            strattag='Nonrain'
          elif istrat == 1:
            # strattag='Conv'
            # strattag='Shall'
            strattag='Deep'
          elif istrat == 2:
            strattag='Strat'
            # strattag='Anvil'
          elif istrat == 3:
            strattag='Precip'
          elif istrat == 4:
            strattag='DC+Strat+Anvil'
          elif istrat == 5:
            strattag='Anvil'
          fig_extra='_'+strattag.lower()
          print("Strat tag: ",strattag)

#### Calculate frequency ##############################################

        #### Basic mean
        var_mn=np.zeros([ntest,nz])

        for ktest in range(ntest):

          # Classification-specific indices
          # if ((istrat != -1) & (istrat != 3)):
          #   ind = (strat_all[ktest] == istrat).nonzero()
          if (istrat == 1):
            # Deep convective only
            ind = ((strat_all[ktest] == 1)).nonzero()
          if (istrat == 2):
            # Stratiform only
            ind = (strat_all[ktest] == 4).nonzero()
          elif istrat == 3:
            # Precip: exclude non-precipitating
            ind = ((strat_all[ktest] != 0)).nonzero()
          elif istrat == 4:
            # Storm: DC + Strat + Anvil
            ind = ((strat_all[ktest] == 1) | (strat_all[ktest] == 4) | (strat_all[ktest] == 5)).nonzero()
          elif istrat == 5:
             # Anvil
             ind = ((strat_all[ktest] == 5)).nonzero()

          var_test = var_all[ktest]
          var_mn[ktest,:] = np.ma.mean(var_test[ind[0],ind[1],:,ind[2],ind[3]], axis=0)

          for iz in range(nz):
            var_slice = var_test[:,:,iz,...]
            var_strat = var_slice[ind]
            print(bins)
            count, placeholder = np.histogram(var_strat, bins=bins)
            var_freq[ktest,:,iz] = 100 * count / np.sum(count) # /(nx1*nx2*nt*nmem)


# ### Plotting routines ##############################################

        font = {'family' : 'sans-serif',
                'weight' : 'normal',
                'size'   : 16}

        matplotlib.rc('font', **font)


# ### Plot CFADs for both tests ################################

        for ktest in range(ntest):

          itest=tests[ktest]

          # pltvar = np.transpose(var_freq[ktest,:,:])
          pltvar = np.transpose(np.ma.masked_equal(var_freq[ktest,:,:],0))
          var_mn_plt = var_mn[ktest,:]*scale_mn

          fig, axd = plt.subplots(nrows=1, ncols=2, gridspec_kw={'width_ratios': [3, 1]},
                                  constrained_layout=True, figsize=(12, 8))

          ifig_title=fig_title+' ('+itest.upper()+')'
          if istrat != -1:
              ifig_title+=' ('+strattag+')'
          fig.suptitle(ifig_title)

          for col in range(2):
              
              ax = plt.subplot(1,2,1+col)

              ax.set_yscale('log')
              ax.invert_yaxis()
              ax.yaxis.set_major_formatter(ticker.ScalarFormatter())
              ax.tick_params(axis='both',length=7)
              ytick_loc=np.arange(900,0,-100)
              plt.yticks(ticks=ytick_loc)
              plt.ylim(np.max(pres), 100)#np.min(pres))

              ax.set_xlabel(units_var)


          ####### Fill contour ##############

              if col == 0:

                  ax.set_title('CFAD')
                  ax.set_ylabel('Pressure [hPa]')

                  if (iplot == 'vmf') or ('wpth' in iplot):
                      ax.set_xscale('symlog')
                      clevs=np.concatenate(([1e-2],np.arange(2,11,2)*1e-2,np.arange(2,11,2)*1e-1,np.arange(2,11,2)))
                      
                      locmin = ticker.SymmetricalLogLocator(base=10.0,linthresh=2,subs=np.arange(2,11,2)*0.1)
                      ax.xaxis.set_major_locator(locmin)
                      ticks=[1e-2,1e-1,1,1e1]
                  else: #if iplot == 'thv' or iplot == 'the':
                      clevs=[0.01,0.05,0.1,0.5,1,5,10,50]
                      # clevs=[1,2,5,10,25,50,100,200,500]
                      ticks=None
          
                  im = ax.contourf(bin_axis, pres, pltvar, clevs, norm=colors.LogNorm(),
                                  cmap=cmocean.cm.ice_r, alpha=1.0, extend='max', zorder=2)
                  
                  plt.xlim(np.min(bin_axis), np.max(bin_axis))
                  
                  # ax2 = ax.twinx()
                  # plt.ylim()
                  # plt.plot(bin_axis,var_freq_int[ktest,:])

                  cbar = plt.colorbar(im, ax=ax, shrink=0.75, ticks=ticks, format=ticker.LogFormatterMathtext())
                  cbar.ax.set_ylabel('%')


          ####### Mean profile ##############

              elif col == 1:
          
                  ax.set_title('Mean')
                  ax.yaxis.set_major_formatter(ticker.NullFormatter())
                  
                  ax.plot(var_mn_plt, pres, "-k", linewidth=2)
                  plt.xlim(xrange_mn)
                  plt.axvline(x=0,color='k',linewidth=0.5)
                  ax.set_xlabel(units_mn)

          plt.savefig(figdir+'cfad_'+fig_tag+fig_extra+'_ens'+enstag+'m_'+itest+'_'+hr_tag+'HiRes.png',dpi=200, facecolor='white', \
                      bbox_inches='tight', pad_inches=0.2)
          plt.close()



        
        # ### Plot difference CFAD ########################
        
        pltvar = np.transpose( var_freq[1,:,:] - var_freq[0,:,:] ) # TEST - CTL
        var_mn_plt = (var_mn[1,:] - var_mn[0,:])*scale_mn # TEST - CTL
        
        fig, axd = plt.subplots(nrows=1, ncols=2, gridspec_kw={'width_ratios': [3, 1]},
                                constrained_layout=True, figsize=(12, 8))
        
        if tests[1] == 'ncrf36h':
           testtag = 'NCRF'
           orderlabel = 'a) '
        elif tests[1] == 'STRATANVIL_OFF':
           testtag = 'SA-NCRF'
           orderlabel = 'c) '
        elif tests[1] == 'STRATANVIL_ON':
           testtag = 'C-NCRF'
           orderlabel = 'b) '
        elif tests[1] == 'STRAT_OFF':
           testtag = 'S-NCRF'
           orderlabel = 'd) '
        elif tests[1] == 'ncrf48h':
            testtag = 'NCRF'
            orderlabel = 'a) '
        
        ifig_title=orderlabel+fig_title+' ('+testtag.upper()+' - '+tests[0].upper()+')' # TEST - CTL
        if istrat != -1:
            ifig_title+=' ('+strattag+')'
        fig.suptitle(ifig_title)
        
        for col in range(2):
        
            ax = plt.subplot(1,2,1+col)

            ax.set_yscale('log')
            ax.invert_yaxis()
            ax.yaxis.set_major_formatter(ticker.ScalarFormatter())
            ax.tick_params(axis='both',length=7)
            ytick_loc=np.arange(900,0,-100)
            plt.yticks(ticks=ytick_loc)
            plt.ylim(np.max(pres), 100)#np.min(pres))

            ax.set_xlabel(units_var)
        
        
        ####### Fill contour ##############
        
            if col == 0:
        
                ax.set_title('CFAD')
                ax.set_ylabel('Pressure [hPa]')

                if (iplot == 'vmf') or ('wpth' in iplot):
                    ax.set_xscale('symlog')
                    clevsi=np.concatenate(([1e-2],np.arange(2,11,2)*1e-2,np.arange(2,11,2)*1e-1,np.arange(2,11,2)*1e-0))

                    locmin = ticker.SymmetricalLogLocator(base=10.0,linthresh=2,subs=np.arange(2,11,2)*0.1)
                    ax.xaxis.set_major_locator(locmin)
                else: #if iplot == 'thv' or iplot == 'the':
                    if iplot == 'qrad':
                      clevsi=np.concatenate(([1e-2],np.arange(2,11,2)*1e-2,np.arange(2,11,2)*1e-1,np.arange(2,11,2)*1e0,np.arange(2,11,2)*1e1))
                    else:
                      clevsi=np.concatenate(([1e-2],np.arange(2,11,2)*1e-2,np.arange(2,11,2)*1e-1,np.arange(2,11,2)*1e0))

                clevs = np.concatenate((-1*np.flip(clevsi),clevsi))

                im = ax.contourf(bin_axis, pres, pltvar, clevs, norm=colors.SymLogNorm(base=10,linthresh=clevsi[0],linscale=clevsi[0]),
                                cmap='RdBu_r', alpha=1.0, extend='both', zorder=2)

                plt.xlim(np.min(bin_axis), np.max(bin_axis))

                # if iplot == 'thv':
                plt.axvline(x=0,color='k',linewidth=1.)

                cbar = plt.colorbar(im, ax=ax, shrink=0.75, ticks=ticker.SymmetricalLogLocator(base=10.0, linthresh=.5),
                                    format=ticker.LogFormatterMathtext())
                cbar.ax.set_ylabel('%')

          ####### Mean profile ##############

            elif col == 1:

                ax.set_title('Mean')
                ax.yaxis.set_major_formatter(ticker.NullFormatter())

                ax.plot(var_mn_plt, pres, "-k", linewidth=2)
                # plt.xticks(np.arange(xrange_mn2[0],xrange_mn2[1]+0.5,0.5), rotation=45) # VMF
                plt.xticks(np.arange(xrange_mn2[0],xrange_mn2[1]+0.3,0.3), rotation=45) # thetas
                plt.xlim(xrange_mn2)
                plt.axvline(x=0,color='k',linewidth=0.5)
                ax.set_xlabel(units_mn)

        difftag='diffSwitch' # TEST - CTL
        switchtag = ''
        if tests[0] == 'crfon': difftag+='v2'
        plt.savefig(figdiffdir+'cfad_'+fig_tag+fig_extra+'_ens'+enstag+'m_'+difftag+'_'+hr_tag+'ctlvs'+tests[1]+'HiRes_850track.png',dpi=200, facecolor='white', \
                    bbox_inches='tight', pad_inches=0.2)
        plt.show()
        plt.close()

        # plt_diff_compare(var_mn_plt)