In [71]:
import numpy as np
import xarray as xr
import pandas as pd
#import metpy.calc as mpcalc
from glob import glob
import holoviews as hv
hv.extension('bokeh')
#%matplotlib inline

In [72]:
def save_output_monthly_chunks(start_year, end_year, var_data, output_name):
    """Write ``var_data`` to disk as one NetCDF file per calendar month.

    For every month of every year from ``start_year`` to ``end_year`` (both
    inclusive), the samples whose ``time`` coordinate falls inside that month
    are selected and written to ``<output_name>.<YYYY><MM>.nc``.

    Parameters
    ----------
    start_year, end_year : int
        First and last year to export (inclusive).
    var_data : object exposing ``.sel(time=slice(...))`` whose result exposes ``.to_netcdf(path)``
        The data to split; in this notebook it is an xarray ``DataArray`` with
        a datetime ``time`` coordinate.
    output_name : str
        Path prefix of the output files (directory plus file stem).

    Notes
    -----
    A file is written for every month in the range, even when the selection
    for that month is empty.
    """
    for year in range(start_year, end_year + 1):
        for month in range(1, 13):
            # A partial date string on both ends of the slice covers the whole
            # calendar month (first to last instant), for every month alike.
            # The previous "first of next month minus half a day" end point
            # stopped at noon of the last day and dropped sub-daily samples
            # after 12:00, and needed a separate branch for December.
            month_label = f'{year}-{month:02d}'
            monthly_data = var_data.sel(time=slice(month_label, month_label))

            output_file = f'{output_name}.{year}{month:02d}.nc'
            monthly_data.to_netcdf(output_file)

### Column-integrate a variable that contains NaN values

In [73]:
def mass_weighted_vertical_integral_w_nan(variable_to_integrate, pressure_model_level_midpoint_Pa, pressure_model_level_interface_Pa, max_pressure_integral_array_Pa, min_pressure_integral_array_Pa):
    """Mass-weighted column integral and average of a field that may contain NaN.

    Parameters
    ----------
    variable_to_integrate : xarray.DataArray
        Field to integrate; summed over its ``lev`` dimension.
    pressure_model_level_midpoint_Pa : xarray.DataArray
        Used only as a template (shape/coords) for the layer thickness ``dp``.
        Assumes it has the same layout as the difference of consecutive
        interfaces along ``ilev`` -- TODO confirm against the caller.
    pressure_model_level_interface_Pa : xarray.DataArray
        Interface pressures with an ``ilev`` dimension.
    max_pressure_integral_array_Pa, min_pressure_integral_array_Pa : scalar or array
        Upper and lower pressure limits of the integral (scalars or arrays
        are both accepted).

    Returns
    -------
    ci_variable
        ``sum(variable * dp / g)`` over ``lev``; NaN where no layer contributes.
    dp_total
        ``sum(dp)`` over ``lev``, counting only layers where the variable is not NaN.
    mwa_variable
        ``ci_variable * g / dp_total``, the mass-weighted vertical average.
    """
    g = 9.8 # [m s^-2]

    # Clamp every interface into [min, max]: layers lying entirely outside the
    # requested range then have zero thickness and carry no mass.
    clipped_interfaces = pressure_model_level_interface_Pa.where(
        pressure_model_level_interface_Pa < max_pressure_integral_array_Pa,
        other=max_pressure_integral_array_Pa,
    )
    clipped_interfaces = clipped_interfaces.where(
        clipped_interfaces > min_pressure_integral_array_Pa,
        other=min_pressure_integral_array_Pa,
    )

    # Layer thickness = lower interface minus upper interface, written into a
    # copy of the midpoint array so it carries the midpoint coordinates.
    # A plain ndarray is assigned to .values (no intermediate DataArray).
    dp = pressure_model_level_midpoint_Pa.copy()
    dp.values = (
        clipped_interfaces.isel(ilev=slice(1, None)).values
        - clipped_interfaces.isel(ilev=slice(0, -1)).values
    )

    # Layers with missing data get dp = NaN so their mass is excluded from
    # dp_total. `.notnull()` is used rather than `xr.ufuncs.isnan`, because the
    # `xr.ufuncs` namespace is not available in every xarray release.
    dp = dp.where(variable_to_integrate.notnull())

    # Mass-weight each layer and integrate; min_count=1 makes an all-NaN
    # column come out as NaN instead of 0.
    ci_variable = (variable_to_integrate * dp / g).sum('lev', min_count=1)
    dp_total = dp.sum('lev', min_count=1)

    # Columns with no mass (dp_total zero or NaN) have no defined integral.
    ci_variable = ci_variable.where((dp_total != 0) & dp_total.notnull())

    # Mass-weighted vertical average over the integrated layer.
    mwa_variable = ci_variable * g / dp_total

    return ci_variable, dp_total, mwa_variable

# Define input directories and file names
<br>
h0:  monthly mean  (1 time step per monthly file)
<br>
h1:  daily mean  (30 time steps per 30-day file)
<br>
h2:  6-hourly  (120 time steps per 30-day file)
<br>
h3:  3-hourly  (240 time steps per 30-day file)
<br>
h4:  hourly  (240 time steps per 10-day file)

In [74]:
# Years to analyze (both ends inclusive)
start_year = 2010
end_year = 2011

################
###.  ERA5.  ###
################

# Directory holding the daily ERA5 files; kept in one place so that relocating
# the dataset is a one-line change. Each pattern below globs one variable.
era5_daily_dir = '/Projects/era5_regrid/2p5_benedict/ERA5_2.5deg_daily/'

# Atmosphere

input_file_string_specific_humidity = era5_daily_dir + 'shum.2p5.*.daily.nc' # Specific humidity
input_file_string_temperature = era5_daily_dir + 'air.2p5.*.daily.nc' # Temperature
input_file_string_surface_pressure = era5_daily_dir + 'pres.sfc.2p5.*.daily.nc' # Surface Pressure
input_file_string_v_wnd = era5_daily_dir + 'vwnd.2p5.*.daily.nc' # v wind
input_file_string_u_wnd = era5_daily_dir + 'uwnd.2p5.*.daily.nc' # u wind
input_file_string_hgt = era5_daily_dir + 'hgt.2p5.*.daily.nc' # geopotential height
input_file_string_omega = era5_daily_dir + 'omega.2p5.*.daily.nc' # vertical velocity

# Surface turbulent fluxes and net radiative fluxes (surface / top)

input_file_string_latent_flux = era5_daily_dir + 'mslhf.2p5.*.daily.nc'
input_file_string_sensible_flux = era5_daily_dir + 'msshf.2p5.*.daily.nc'
input_file_string_longwave_flux_surface = era5_daily_dir + 'msnlwrf.2p5.*.daily.nc'
input_file_string_longwave_flux_top = era5_daily_dir + 'mtnlwrf.2p5.*.daily.nc'
input_file_string_shortwave_flux_surface = era5_daily_dir + 'msnswrf.2p5.*.daily.nc'
input_file_string_shortwave_flux_top = era5_daily_dir + 'mtnswrf.2p5.*.daily.nc'

# Land
#input_file_string_land_frac = '/glade/work/bwolding/Datasets/Data_for_Glade/ERAi/land_sea_mask.erai.2p5.nc' # Land Fraction 

### Load in data and calculations

In [75]:
# Define constants

g = 9.8 #[m s^-2]  multiplies geopotential height in the MSE definition
# NOTE(review): 2.26e6 J/kg is the latent heat of vaporization near 100 degC;
# values near 2.5e6 J/kg (0 degC) are also common -- confirm the intended value.
L = 2.26e6 #[J/kg]  multiplies specific humidity in the MSE definition
cp = 1005 #[J/kg-K]  multiplies temperature in the MSE definition
R_e = 6.378e6 #[m]  sphere radius used to turn angular grid steps into distances
pi = np.pi  # was 22/7, which overestimates pi by about 0.04 %

#########################################
# Define paths of files we wish to load #
#########################################

# glob expands paths with * to a list of files, like the unix shell #

paths_specific_humidity = glob(input_file_string_specific_humidity)
paths_temperature = glob(input_file_string_temperature)
paths_surface_pressure = glob(input_file_string_surface_pressure)
paths_v_wnd = glob(input_file_string_v_wnd)
paths_u_wnd = glob(input_file_string_u_wnd)
paths_hgt = glob(input_file_string_hgt)
paths_omega = glob(input_file_string_omega)

paths_lhf = glob(input_file_string_latent_flux)
paths_shf = glob(input_file_string_sensible_flux)
paths_longwave_surf = glob(input_file_string_longwave_flux_surface)
paths_longwave_top = glob(input_file_string_longwave_flux_top)
paths_shortwave_surf = glob(input_file_string_shortwave_flux_surface)
paths_shortwave_top = glob(input_file_string_shortwave_flux_top)

# Limit paths #

def _paths_matching_years(paths, years):
    """Return the entries of ``paths`` whose text contains one of ``years``.

    The result is grouped by year in the order ``years`` is iterated; within a
    year the order of ``paths`` is kept. A path containing several of the
    years appears once per matching year.
    """
    return [path for year in years for path in paths if str(year) in path]


analysis_years = range(start_year, end_year + 1)

for year in analysis_years:

    print(year)

    # Define year strings #

    current_year_string = str(year)

# One filter call per variable replaces twelve copies of the same loop
year_limited_paths_specific_humidity = _paths_matching_years(paths_specific_humidity, analysis_years)
year_limited_paths_temperature = _paths_matching_years(paths_temperature, analysis_years)
year_limited_paths_u_wnd = _paths_matching_years(paths_u_wnd, analysis_years)
year_limited_paths_v_wnd = _paths_matching_years(paths_v_wnd, analysis_years)
year_limited_paths_hgt = _paths_matching_years(paths_hgt, analysis_years)
year_limited_paths_omega = _paths_matching_years(paths_omega, analysis_years)

year_limited_paths_lhf = _paths_matching_years(paths_lhf, analysis_years)
year_limited_paths_shf = _paths_matching_years(paths_shf, analysis_years)
year_limited_paths_longwave_surf = _paths_matching_years(paths_longwave_surf, analysis_years)
year_limited_paths_longwave_top = _paths_matching_years(paths_longwave_top, analysis_years)
year_limited_paths_shortwave_surf = _paths_matching_years(paths_shortwave_surf, analysis_years)
year_limited_paths_shortwave_top = _paths_matching_years(paths_shortwave_top, analysis_years)

#####################
####  Open Data  ####
#####################

# Datasets are opened lazily: nothing is actually read until the data is "looked at" in some way #

def _open_by_coords(file_paths):
    """Open ``file_paths`` as one multi-file dataset, combined by coordinates."""
    return xr.open_mfdataset(file_paths, combine="by_coords")


# Pressure-level fields
dataset_specific_humidity = _open_by_coords(year_limited_paths_specific_humidity)
dataset_temperature = _open_by_coords(year_limited_paths_temperature)
dataset_u_wnd = _open_by_coords(year_limited_paths_u_wnd)
dataset_v_wnd = _open_by_coords(year_limited_paths_v_wnd)
dataset_hgt = _open_by_coords(year_limited_paths_hgt)
dataset_omega = _open_by_coords(year_limited_paths_omega)

# Surface turbulent fluxes and radiative fluxes
dataset_lhf = _open_by_coords(year_limited_paths_lhf)
dataset_shf = _open_by_coords(year_limited_paths_shf)
dataset_longwave_surf = _open_by_coords(year_limited_paths_longwave_surf)
dataset_longwave_top = _open_by_coords(year_limited_paths_longwave_top)
dataset_shortwave_surf = _open_by_coords(year_limited_paths_shortwave_surf)
dataset_shortwave_top = _open_by_coords(year_limited_paths_shortwave_top)

###########################
####  Subset and load  ####
###########################

# Coordinates of the temperature dataset before any subsetting #
full_lat = dataset_temperature['lat']
full_lon = dataset_temperature['lon']
time = dataset_temperature['time']

# Selections shared by the fields below #
analysis_period = slice(str(start_year)+'-01-01', str(end_year)+'-12-31')
level_range = slice(70, 1000)
# The two file families are sliced with opposite latitude bounds.
# NOTE(review): presumably the pressure-level files store latitude
# north-to-south and the flux files south-to-north -- confirm.
lat_band_pressure_levels = slice(15, -15)
lat_band_fluxes = slice(-15, 15)

#PS = dataset_surface_pressure['pres'].sel(time = slice(str(year)+'-01-01', str(year)+'-12-31'), lat = slice(15, -15)) # [Pa]

# Pressure-level fields #
Q = dataset_specific_humidity['shum'].sel(time=analysis_period, lat=lat_band_pressure_levels, level=level_range) # [Kg/Kg]
T = dataset_temperature['air'].sel(time=analysis_period, lat=lat_band_pressure_levels, level=level_range) # [K]
U_wnd = dataset_u_wnd['uwnd'].sel(time=analysis_period, lat=lat_band_pressure_levels, level=level_range) # presumably [m/s] -- TODO confirm
V_wnd = dataset_v_wnd['vwnd'].sel(time=analysis_period, lat=lat_band_pressure_levels, level=level_range) # presumably [m/s] -- TODO confirm
HGT = dataset_hgt['hgt'].sel(time=analysis_period, lat=lat_band_pressure_levels, level=level_range) # presumably [m] -- TODO confirm
OMEGA = dataset_omega['omega'].sel(time=analysis_period, lat=lat_band_pressure_levels, level=level_range) # presumably [Pa/s] -- TODO confirm

# Single-level flux fields #
LHF = dataset_lhf['mslhf'].sel(time=analysis_period, lat=lat_band_fluxes)
SHF = dataset_shf['msshf'].sel(time=analysis_period, lat=lat_band_fluxes)
LW_surf = dataset_longwave_surf['msnlwrf'].sel(time=analysis_period, lat=lat_band_fluxes)
LW_top = dataset_longwave_top['mtnlwrf'].sel(time=analysis_period, lat=lat_band_fluxes)
SW_surf = dataset_shortwave_surf['msnswrf'].sel(time=analysis_period, lat=lat_band_fluxes)
SW_top = dataset_shortwave_top['mtnswrf'].sel(time=analysis_period, lat=lat_band_fluxes)

# Actually load data (in the same order as before) #
for _lazy_field in (Q, T, U_wnd, V_wnd, HGT, OMEGA, LHF, SHF, LW_surf, LW_top, SW_surf):
    _lazy_field.load()

# Kept as a bare call so it remains the final expression of the cell
SW_top.load()

2010
2011


In [76]:
# Sanity check: print the text repr of the loaded 'msnswrf' (surface shortwave) array
print(SW_surf)

<xarray.DataArray 'msnswrf' (time: 730, lat: 13, lon: 144)>
array([[[272.42447917, 303.70182292, 274.02083333, ..., 250.45963542,
         240.11328125, 294.24088542],
        [280.31119792, 269.8984375 , 258.88411458, ..., 256.77213542,
         236.17838542, 275.71614583],
        [268.96744792, 262.47395833, 249.19401042, ..., 260.81119792,
         271.40234375, 285.24348958],
        ...,
        [187.45052083, 211.08854167, 210.1171875 , ..., 211.32552083,
         205.01041667, 208.5625    ],
        [176.41536458, 198.95052083, 185.85807292, ..., 196.35546875,
         187.99609375, 185.99479167],
        [161.09895833, 155.49869792, 165.57291667, ..., 169.625     ,
         166.76432292, 165.85286458]],

       [[284.21875   , 276.70442708, 259.64713542, ..., 292.47265625,
         298.55338542, 261.23697917],
        [250.17838542, 208.43098958, 240.37630208, ..., 286.63020833,
         261.54947917, 272.07291667],
        [272.72005208, 267.40364583, 267.296875  , ..., 252.2

In [77]:
# Moist static energy built from temperature (T), specific humidity (Q) and
# geopotential height (HGT): MSE = cp*T + L*Q + g*HGT
# (presumably in [J/kg], given the units noted on the constants -- TODO confirm)
MSE = cp*T + L*Q + g*HGT

# Bare expression: displays the resulting array in the notebook
MSE

In [78]:

# Grid coordinates of MSE in radians, expanded to 2-D arrays of shape (lat, lon)
latitudes = np.radians(MSE['lat'].values)
longitudes = np.radians(MSE['lon'].values)
lon_grid, lat_grid = np.meshgrid(longitudes, latitudes)

# np.gradient on a 2-D array returns one array per axis:
# index 0 is the step along lat (rows), index 1 the step along lon (columns)
dlon = np.gradient(lon_grid)
dlat = np.gradient(lat_grid)
lon_step_rad = dlon[1]
lat_step_rad = dlat[0]

# Convert angular steps to distances on a sphere of radius R_e:
# zonal distance shrinks with cos(lat), meridional distance does not
dlon_grid = R_e * np.cos(lat_grid) * lon_step_rad
dlat_grid = R_e * lat_step_rad

# Display the grid shape as a quick check
dlon_grid.shape

(13, 144)

In [79]:



# Finite differences of MSE per grid index along axes 2 and 3.
# Assumes MSE is laid out as (time, level, lat, lon) -- TODO confirm.
dMSE_dlat = np.gradient(MSE,axis = 2)
dMSE_dlon = np.gradient(MSE,axis = 3)

# Grid distances expanded to the full MSE shape. np.broadcast_to returns
# read-only views of the 2-D grids, so no full-size copies are allocated
# (np.tile built two complete 4-D arrays); the quotients are unchanged.
dlat_full = np.broadcast_to(dlat_grid, MSE.shape)
dlon_full = np.broadcast_to(dlon_grid, MSE.shape)

# Index differences -> spatial derivatives
dMSE_dlat = dMSE_dlat/dlat_full
dMSE_dlon = dMSE_dlon/dlon_full

udMSE_dx = U_wnd*dMSE_dlon #zonal advection
vdMSE_dy = V_wnd*dMSE_dlat #meridional advection

# Total horizontal advection: u * dMSE/dx + v * dMSE/dy
HADV = udMSE_dx + vdMSE_dy



In [80]:
# Shape of the MSE array.
# NOTE(review): `a` is not referenced anywhere else in the visible cells.
a = MSE.shape

In [81]:
# Derivative of MSE with respect to the 'level' coordinate.
# NOTE(review): the /100 presumably converts a per-hPa derivative to per-Pa;
# confirm that 'level' is stored in hPa.
dMSE_dp = MSE.differentiate(coord= 'level')/100

# Vertical advection term: omega * dMSE/dp
VADV = OMEGA*dMSE_dp

#VADV



In [82]:
# Name of the dimension that holds the pressure levels
pressure_dim = 'level'  # Replace with your dataset's dimension name

# Layer weights: step between adjacent levels, scaled by 100/g
pressure_levels = MSE[pressure_dim].values
level_steps = np.diff(pressure_levels)*100/g

# np.diff yields one value fewer than there are levels, so the first step is
# repeated at the front to get one weight per level.
# NOTE(review): with this scheme the weights add up to
# (p_last - p_first) + (p_1 - p_0), i.e. the first interval is counted for
# both level 0 and level 1 -- confirm this is the intended quadrature.
pressure_spacing = np.insert(level_steps, 0, level_steps[0])

# Wrap the weights in a DataArray on the pressure dimension so they broadcast
# against the 4-D fields
pressure_spacing_broadcasted = xr.DataArray(pressure_spacing, dims=pressure_dim)


def _column_integral(field):
    """Weight ``field`` by the layer weights and sum over the pressure dimension."""
    return (field* pressure_spacing_broadcasted).sum(dim=pressure_dim)


# Mass-weighted vertical integrals of the budget terms
col_MSE = _column_integral(MSE)
col_HADV = _column_integral(HADV)
col_VADV = _column_integral(VADV)
col_HADV_zonal = _column_integral(udMSE_dx)
col_HADV_meridional = _column_integral(vdMSE_dy)

print(col_MSE)
#commenting to see what happens



<xarray.DataArray (time: 730, lat: 13, lon: 144)>
array([[[3.20377937e+09, 3.20301533e+09, 3.20201408e+09, ...,
         3.20796370e+09, 3.20589500e+09, 3.20477659e+09],
        [3.21882597e+09, 3.21417541e+09, 3.21286717e+09, ...,
         3.21717234e+09, 3.21958134e+09, 3.21934785e+09],
        [3.24029983e+09, 3.23976004e+09, 3.23884327e+09, ...,
         3.24465991e+09, 3.23785763e+09, 3.23831765e+09],
        ...,
        [3.26473342e+09, 3.26432509e+09, 3.26486936e+09, ...,
         3.25076663e+09, 3.25058596e+09, 3.25653120e+09],
        [3.25677832e+09, 3.25668830e+09, 3.26432144e+09, ...,
         3.26571978e+09, 3.26299156e+09, 3.26155393e+09],
        [3.26673695e+09, 3.26688090e+09, 3.27376693e+09, ...,
         3.25654803e+09, 3.26153626e+09, 3.26393423e+09]],

       [[3.20300858e+09, 3.20202110e+09, 3.20220989e+09, ...,
         3.20710427e+09, 3.20588357e+09, 3.20410747e+09],
        [3.21877478e+09, 3.21275991e+09, 3.20730390e+09, ...,
         3.22493821e+09, 3.221003

In [83]:
#### Surface and radiative flux calculation

# Latent plus sensible heat flux with the sign flipped.
# NOTE(review): presumably the input fluxes are positive downward, which would
# make this the flux from the surface into the atmosphere -- confirm.
surface_flux = - LHF - SHF

# Net radiation at the top (SW_top + LW_top) minus net radiation at the
# surface (SW_surf + LW_surf)
net_rad_flux = -LW_surf - SW_surf + SW_top + LW_top

In [84]:
### MSE tendency calculation

# Time derivative of the column-integrated MSE; datetime_unit="s" expresses
# the derivative per second
MSE_tend = col_MSE.differentiate(coord= 'time',datetime_unit="s")


### Save budget terms in monthly chunks

In [85]:
# Output directory for datasets
odir_datasets = '/Projects/era5_regrid/2p5_vijit/'

# Output path prefixes: one per budget term; "col_" marks the column integral
output_file_string_MSE = odir_datasets+'MSE'
output_file_string_colMSE = odir_datasets+'col_MSE'
output_file_string_MSE_tend = odir_datasets+'MSE_tend'
output_file_string_VADV = odir_datasets+'VADV'
output_file_string_colVADV = odir_datasets+'col_VADV'
output_file_string_HADV = odir_datasets+'HADV'
output_file_string_colHADV = odir_datasets+'col_HADV'
output_file_string_HADVzonal = odir_datasets+'HADV_zonal'
output_file_string_colHADVzonal = odir_datasets+'col_HADV_zonal'
output_file_string_HADVmeridional = odir_datasets+'HADV_meridional'
output_file_string_colHADVmeridional = odir_datasets+'col_HADV_meridional'
output_file_string_surf_flux = odir_datasets+'SF'
output_file_string_rad_flux_net = odir_datasets+'RadF_net'

# (field, path prefix) pairs, listed in the order in which they are written
_budget_outputs = (
    (MSE, output_file_string_MSE),
    (col_MSE, output_file_string_colMSE),
    (MSE_tend, output_file_string_MSE_tend),
    (VADV, output_file_string_VADV),
    (col_VADV, output_file_string_colVADV),
    (HADV, output_file_string_HADV),
    (col_HADV, output_file_string_colHADV),
    (udMSE_dx, output_file_string_HADVzonal),
    (col_HADV_zonal, output_file_string_colHADVzonal),
    (vdMSE_dy, output_file_string_HADVmeridional),
    (col_HADV_meridional, output_file_string_colHADVmeridional),
    (surface_flux, output_file_string_surf_flux),
    (net_rad_flux, output_file_string_rad_flux_net),
)

# Write every term as monthly NetCDF chunks covering start_year..end_year
for _budget_field, _output_prefix in _budget_outputs:
    save_output_monthly_chunks(start_year, end_year, _budget_field, _output_prefix)
