### TC energy budget anomaly
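Computes the anomaly of an energy-budget field (e.g., outgoing longwave radiation) for a single tropical cyclone (TC), relative to the time-mean background field sampled around each storm timestamp, and animates the result.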

In [1]:
import xarray as xr
import cartopy, cartopy.crs as ccrs, matplotlib, matplotlib.pyplot as plt
import importlib
import cftime
import functools
import numpy as np
import os
import pandas as pd
import random
from multiprocessing import Pool

import utilities
import visualization

importlib.reload(utilities);

In [2]:
def get_filename_year(filename: str) -> int:
    ''' Extract the year that leads the storm ID token (`storm_ID-<year>-...`) of a TC filename. '''
    # Text after the last `storm_ID-` marker (the whole name if the marker is absent)
    storm_id_token = filename.rpartition('storm_ID-')[2]
    # The year is the first dash-separated piece of that token
    year_string = storm_id_token.partition('-')[0]
    return int(year_string)

In [3]:
def get_filename_intensity(filename: str,
                           delimiter_string: str) -> int:
    ''' Extract the integer that follows `<delimiter_string>-` in a TC filename, up to the next period. '''
    marker = delimiter_string + '-'
    # Text after the last occurrence of the marker (the whole name if the marker is absent)
    trailing_text = filename.rpartition(marker)[2]
    # The intensity value ends at the next period
    intensity_string = trailing_text.partition('.')[0]
    return int(intensity_string)

In [4]:
def field_correction(model_name: str, 
                     dataset: xr.Dataset,
                     field_name: str) -> xr.Dataset:

    ''' 
    Rescale water-flux fields (`precip`, `evap`, `p-e`) to a daily rate; leave every other field untouched.

    ERA5 fields are scaled by (1 / 3600) * 1000 * 86400 (total per hour to daily rate) and
    fields from any other model by 86400 (instantaneous rate to daily rate).

    Note: the input `dataset` is modified in place and also returned, so calling this twice
    on the same object applies the conversion twice.
    '''

    # Fields outside this list need no conversion. Returning early avoids a needless
    # multiply-by-one, which copies the data, can strip the variable's attributes
    # (e.g., `units`), and fails if the field is absent.
    if field_name not in ['precip', 'evap', 'p-e']:
        return dataset

    # Conversion of total precipitation per hour to daily precipitation rate for ERA5 data, instantaneous rate to daily for GFDL GCM data
    factor = (1 / 3600) * 1000 * 86400 if model_name == 'ERA5' else 86400
    dataset[field_name] = dataset[field_name] * factor

    return dataset

In [5]:
def get_time_window(dataset: xr.Dataset,
                    timestamp,
                    window_day_size: int):

    ''' 
    Filter a dataset to the times whose day of year lies within `window_day_size` days of the timestamp's day of year.

    The comparison uses day of year only, so every year present in `dataset` contributes to the window.
    Windows that cross the year boundary wrap around, using a nominal 365-day year.

    Arguments:
    - dataset: object with a datetime-like `time` coordinate.
    - timestamp: cftime datetime, pandas/numpy/stdlib datetime, or a 0-d xarray datetime array.
    - window_day_size: half-width of the window, in days.
    Returns the subset of `dataset` along `time`.
    '''

    # Nominal year length used for wrapping windows across the year boundary
    days_per_year = 365

    # Obtain day of year for the timestamp (handled differently by time object type)
    if 'cftime' in str(type(timestamp)):
        timestamp_day_of_year = int(timestamp.dayofyr)
    elif isinstance(timestamp, xr.DataArray):
        timestamp_day_of_year = int(timestamp.dt.dayofyear)
    else:
        # pandas.Timestamp, numpy.datetime64 and datetime.datetime have no `.dt` accessor
        timestamp_day_of_year = int(pd.Timestamp(timestamp).dayofyear)
    # Dataset time array days of year (the `.dt` accessor covers cftime and numpy time axes alike)
    dataset_day_of_year = dataset.time.dt.dayofyear
    # Get start and end days of year
    start_day_of_year, end_day_of_year = timestamp_day_of_year - window_day_size, timestamp_day_of_year + window_day_size
    # Mask by window from start_day_of_year to end_day_of_year
    window = (dataset_day_of_year >= start_day_of_year) & (dataset_day_of_year <= end_day_of_year)
    # Wrap the part of the window that falls before January 1 onto the end of the year
    if start_day_of_year < 1:
        window = window | (dataset_day_of_year >= start_day_of_year + days_per_year)
    # Wrap the part of the window that falls after December 31 onto the start of the year
    if end_day_of_year > days_per_year:
        window = window | (dataset_day_of_year <= end_day_of_year - days_per_year)
    # Mask data by the window
    dataset_window = dataset.sel(time=window)

    return dataset_window

In [6]:
def grid_check(dataset: xr.Dataset) -> tuple[float, float]:

    ''' 
    Perform grid checks and get grid spacing for a generic xArray coordinate system. 

    Asserts that `grid_xt` and `grid_yt` are dimensions of `dataset`, that each has at least two points,
    and that the spacing along each is uniform to within 1e-6.
    Returns the absolute grid spacings (dx, dy).
    '''

    # Ensure necessary basis vector dimensions are available
    assert ('grid_xt' in dataset.dims) and ('grid_yt' in dataset.dims), 'Dimensions `grid_xt` and `grid_yt` are required.'
    # Get differences in grid spacing along each vector
    d_grid_yt = dataset['grid_yt'].diff(dim='grid_yt')
    d_grid_xt = dataset['grid_xt'].diff(dim='grid_xt')
    # A spacing is only defined with two or more points along an axis
    assert len(d_grid_xt) >= 1, 'At least two points are required along the `grid_xt` axis.'
    assert len(d_grid_yt) >= 1, 'At least two points are required along the `grid_yt` axis.'
    # Ensure that the differences are equivalent for all indices to ensure equal spacing
    # The absolute value is needed so that spacing which shrinks along the axis is caught as well as spacing which grows
    grid_tolerance = 1e-6
    assert bool((np.abs(d_grid_xt.diff(dim='grid_xt')) < grid_tolerance).all()), 'Grid is irregular along the `grid_xt` axis.'
    assert bool((np.abs(d_grid_yt.diff(dim='grid_yt')) < grid_tolerance).all()), 'Grid is irregular along the `grid_yt` axis.'
    # Get grid spacing along each direction
    dx, dy = d_grid_xt.isel(grid_xt=0).item(), d_grid_yt.isel(grid_yt=0).item()

    return abs(dx), abs(dy)

In [7]:
def grid_interpolation(working_grid: xr.Dataset,
                       reference_grid: xr.DataArray,
                       diagnostic: bool=False) -> xr.DataArray:

    ''' 
    Method to generate uniform interpolation basis vectors for a TC-centered grid. 

    The basis vectors span the extent of `reference_grid` at its grid spacing (all rounded to 4 decimals),
    and `working_grid` is interpolated onto them along `grid_xt` and then `grid_yt`.
    The basis vectors are always ascending, whatever the ordering of `reference_grid`.
    Set `diagnostic` to print the window extent, spacing, and interpolated latitudes.
    '''

    # Ensure necessary basis vector dimensions are available
    assert ('grid_xt' in working_grid.dims) and ('grid_yt' in working_grid.dims)
    # Check the reference grid for regularity and get its (absolute) grid spacing
    # Absolute spacing matters: a descending axis would give a negative step and an empty basis vector
    dx, dy = grid_check(reference_grid)
    
    # Padding on window search for storm coordinate grid
    # This expands the window `padding_factor` grid cells in each direction of the window extent for a given storm timestamp
    padding_factor = 0
    
    # Get extent of longitudes
    minimum_longitude, maximum_longitude = [reference_grid['grid_xt'].min().item() - dx * padding_factor, 
                                            reference_grid['grid_xt'].max().item() + dx * padding_factor]
    # Get extent of latitudes
    minimum_latitude, maximum_latitude = [reference_grid['grid_yt'].min().item() - dy * padding_factor, 
                                          reference_grid['grid_yt'].max().item() + dy * padding_factor]

    # Round values due to weird GFDL GCM output behavior
    minimum_longitude, maximum_longitude = [np.round(minimum_longitude, decimals=4),
                                            np.round(maximum_longitude, decimals=4)]
    minimum_latitude, maximum_latitude = [np.round(minimum_latitude, decimals=4),
                                          np.round(maximum_latitude, decimals=4)]
    dx, dy = [np.round(dx, decimals=4),
              np.round(dy, decimals=4)]

    if diagnostic:
        print(f'[grid_interpolation()] Window extent: longitudes = {(minimum_longitude, maximum_longitude)} and latitudes= {(minimum_latitude, maximum_latitude)}')
        print(f'[grid_interpolation()] dx = {dx}; dy = {dy}')

    # Construct interpolation arrays for interpolating the area grid onto the storm grid
    interpolation_array_x = np.arange(minimum_longitude, maximum_longitude + dx, dx)
    interpolation_array_y = np.arange(minimum_latitude, maximum_latitude + dy, dy)
    # Floating-point round-off in `np.arange` can append one point beyond the maximum; drop it if present
    interpolation_array_x = interpolation_array_x[interpolation_array_x <= maximum_longitude + dx / 2]
    interpolation_array_y = interpolation_array_y[interpolation_array_y <= maximum_latitude + dy / 2]
    
    if diagnostic:
        print(f'[grid_interpolation()] interpolated latitudes: {interpolation_array_y}')
    
    # Perform interpolation of area grid onto the storm grid
    interpolated_working_grid = working_grid.interp(grid_xt=interpolation_array_x).interp(grid_yt=interpolation_array_y)

    return interpolated_working_grid

In [8]:
def get_surface_area(storm_dataset_timestamp: xr.Dataset | xr.DataArray,
                     diagnostic: bool=False,
                     surface_area_path: str='/projects/GEOCLIM/gr7610/tools/AM2.5_atmos_area.nc') -> xr.DataArray:

    ''' 
    Get grid cell surface areas on the coordinates of a storm snapshot.

    Arguments:
    - storm_dataset_timestamp: storm data with `grid_xt` and `grid_yt` dimensions; all-null rows and columns are dropped first.
    - diagnostic: print the window extent and the latitudes being compared.
    - surface_area_path: netCDF file holding the area field in variable `__xarray_dataarray_variable__`.
    Returns the area field interpolated onto the (null-trimmed) storm coordinates.
    Raises an AssertionError if the interpolated coordinates do not match the storm coordinates.
    '''

    # Shed null values
    storm_dataset_timestamp = storm_dataset_timestamp.dropna('grid_xt', how='all').dropna('grid_yt', how='all')
    # Get minimum and maximum spatial extent values
    minimum_longitude, maximum_longitude = [storm_dataset_timestamp['grid_xt'].min().item(), 
                                            storm_dataset_timestamp['grid_xt'].max().item()]
    minimum_latitude, maximum_latitude = [storm_dataset_timestamp['grid_yt'].min().item(), 
                                          storm_dataset_timestamp['grid_yt'].max().item()]
    # Round values due to weird GFDL GCM output behavior
    minimum_longitude, maximum_longitude = [np.round(minimum_longitude, decimals=4),
                                            np.round(maximum_longitude, decimals=4)]
    minimum_latitude, maximum_latitude = [np.round(minimum_latitude, decimals=4),
                                          np.round(maximum_latitude, decimals=4)]

    # Load surface area data from GFDL GCM output
    # The context manager closes the file handle; this function runs once per storm timestamp (possibly in
    # worker processes), so handles would otherwise accumulate. Only the storm window is read into memory.
    with xr.open_dataset(surface_area_path) as surface_area_dataset:
        surface_area = surface_area_dataset['__xarray_dataarray_variable__']
        # Get surface area at iterand timestamp
        surface_area_timestamp = surface_area.sel(grid_xt=slice(minimum_longitude, maximum_longitude),
                                                  grid_yt=slice(minimum_latitude, maximum_latitude)).load()

    # Interpolate area onto storm coordinates
    interpolated_surface_area_timestamp = grid_interpolation(surface_area_timestamp, storm_dataset_timestamp)

    if diagnostic:
        print('------------------------------------')
        print(f'Window extent: longitudes = {(minimum_longitude, maximum_longitude)} and latitudes= {(minimum_latitude, maximum_latitude)}')
        print(f'[get_surface_area()]: Storm dataset latitudes:\n{storm_dataset_timestamp.grid_yt.values}')
        print(f'[get_surface_area()]: Surface area dataset latitudes:\n{interpolated_surface_area_timestamp.grid_yt.values}')

    # Make sure all longitudes and latitudes are within some tolerance of each other
    # Shapes are compared first so that a length mismatch reports the coordinates instead of a broadcasting error
    storm_longitudes, area_longitudes = storm_dataset_timestamp.grid_xt.values, interpolated_surface_area_timestamp.grid_xt.values
    storm_latitudes, area_latitudes = storm_dataset_timestamp.grid_yt.values, interpolated_surface_area_timestamp.grid_yt.values
    assert storm_longitudes.shape == area_longitudes.shape and np.allclose(storm_longitudes, area_longitudes), f'\nData longitudes:\n{storm_longitudes}; \n Surface area longitudes:\n{area_longitudes}'
    assert storm_latitudes.shape == area_latitudes.shape and np.allclose(storm_latitudes, area_latitudes), f'\nData latitude:\n{storm_latitudes}; \n Surface area latitudes:\n{area_latitudes}'

    return interpolated_surface_area_timestamp

In [22]:
def get_sample_GCM_data(model_name: str,
                        experiment_name: str,
                        field_name: str,
                        year_range: tuple[int, int],
                        sampling_timestamp: pd.Timestamp,
                        longitude: int|float,
                        latitude: int|float,
                        window_size: int,
                        sampling_day_window: int=5,
                        diagnostic: bool=False):

    ''' 
    Method to pull GCM data corresponding to a given TC snapshot. 

    Loads daily-mean data for the month of `sampling_timestamp` over `year_range`, clips it to a box of
    +/- `window_size` (in coordinate units) around (`longitude`, `latitude`), keeps the times within
    +/- `sampling_day_window` days of the timestamp's day of year, and averages over time.

    Note: `sampling_timestamp` only needs a `month` attribute and to be accepted by `get_time_window`;
    within this file it is called with the item of a storm `time` coordinate opened with `use_cftime=True`.
    '''

    # Construct field dictionary for postprocessed data loading
    # See `utilities.postprocessed_data_load` for details.
    # Note: this currently only supports single-surface atmospheric data
    field_dictionary = {field_name: {'domain': 'atmos', 'level': None}}
    # Extract month from the iterand timestamp to perform initial climatology filtering
    sampling_month = sampling_timestamp.month
    # Load the data
    # NOTE(review): only a single month is requested, so the +/- `sampling_day_window` selection below is
    # presumably clipped for timestamps near the start or end of a month - confirm against `postprocessed_data_load`.
    sample_GCM_data = utilities.postprocessed_data_load(model_name,
                                                        experiment_name,
                                                        field_dictionary,
                                                        year_range,
                                                        data_type='mean_daily',
                                                        month_range=(sampling_month, sampling_month),
                                                        load_full_time=True)[model_name][experiment_name]
    # Get GCM grid spacing
    GCM_dx, GCM_dy = grid_check(sample_GCM_data)
    # Define spatial extent for sample clipping
    grid_xt_extent = slice(longitude - window_size, longitude + window_size)
    grid_yt_extent = slice(latitude - window_size, latitude + window_size)
    # Trim the data spatially (latitudes are sorted first so that the ascending slice applies)
    sample_GCM_data_filtered_space = sample_GCM_data.sortby('grid_yt').sel(grid_xt=grid_xt_extent).sel(grid_yt=grid_yt_extent)
    # Subsample over the time window specified: (iterand timestamp - sampling_day_window) to (iterand_timestamp + sampling_day_window)
    sample_GCM_data_filtered_time = get_time_window(sample_GCM_data_filtered_space, sampling_timestamp, sampling_day_window)
    # Average in time
    sample_GCM_data_filtered = sample_GCM_data_filtered_time.mean(dim='time')
    
    if diagnostic:
        print(f'Storm timestamp center: {longitude}, {latitude}.')
        print(f'GCM grid spacing: dx = {GCM_dx}, dy = {GCM_dy}.')
        print(f'Storm timestamp extent = longitude: {grid_xt_extent}, latitude: {grid_yt_extent}.')
        print(f'GCM extent = longitude: {sample_GCM_data.grid_xt.values}, latitude: {sample_GCM_data.grid_yt.values}.')
        print(f'Filtered GCM extent = longitude: {sample_GCM_data_filtered.grid_xt.values}, latitude: {sample_GCM_data_filtered.grid_yt.values}.')

    return sample_GCM_data_filtered

In [23]:
def get_TC_anomaly_timestamp(model_name: str,
                             experiment_name: str,
                             year_range: tuple[int, int],
                             storm_reanalysis_data: xr.Dataset,
                             field_name: str,
                             sampling_timestamp: cftime.datetime,
                             diagnostic: bool=False):

    ''' 
    Get the anomaly of a field for a single TC timestamp, relative to time-averaged GCM data at the storm location.

    Returns a tuple of:
    - the anomaly field (storm field minus the sampled GCM field interpolated to the storm grid), and
    - the sum of the anomaly field weighted by grid cell surface area.
    '''

    # Get timestamp
    storm_sample = storm_reanalysis_data.sel(time=sampling_timestamp)
    sampling_timestamp = storm_sample.time.item()

    # Get sample TC center coordinates
    sample_center_longitude = storm_sample['center_lon'].item()
    sample_center_latitude = storm_sample['center_lat'].item()

    # Load GCM data according to the given sample
    sample_GCM_data = get_sample_GCM_data(model_name, 
                                          experiment_name,
                                          field_name,
                                          year_range,
                                          sampling_timestamp,
                                          sample_center_longitude,
                                          sample_center_latitude,
                                          window_size=10)
    sample_GCM_data['time'] = sampling_timestamp

    # Get surface area on the storm coordinates (`get_surface_area` interpolates internally)
    sample_timestamp_area = get_surface_area(storm_sample)
    # Interpolate the GCM data to the storm data
    sample_GCM_data = grid_interpolation(sample_GCM_data, storm_sample)

    # Get simple anomaly
    TC_climatological_anomaly_timestamp = storm_sample[field_name] - sample_GCM_data[field_name]

    if diagnostic:
        print(f'Storm shape at timestamp {sampling_timestamp}: {storm_sample.grid_xt.shape}, {storm_sample.grid_yt.shape}')
        print(f'GCM shape at timestamp {sampling_timestamp}: {sample_GCM_data.grid_xt.shape}, {sample_GCM_data.grid_yt.shape}')
        print(f'Area shape at timestamp {sampling_timestamp}: {sample_timestamp_area.grid_xt.shape}, {sample_timestamp_area.grid_yt.shape}')
        print(f'Anomaly shape at timestamp {sampling_timestamp}: {TC_climatological_anomaly_timestamp.grid_xt.shape}, {TC_climatological_anomaly_timestamp.grid_yt.shape}')

    # Get area-integrated anomaly
    TC_climatological_anomaly_timestamp_integrated = (TC_climatological_anomaly_timestamp * sample_timestamp_area).sum().item()
    
    if diagnostic:
        # NOTE(review): the units printed are those of the field itself, not of the area-weighted sum - confirm intended label.
        units = storm_sample[field_name].attrs.get('units', '')
        print(f'Area-integrated TC anomaly for timestamp {sampling_timestamp}: {TC_climatological_anomaly_timestamp_integrated:.2e} {units}')

    return TC_climatological_anomaly_timestamp, TC_climatological_anomaly_timestamp_integrated

In [12]:
def get_TC_anomaly(model_name: str, 
                   experiment_name: str,
                   field_name: str,
                   year_range: tuple,
                   storm_reanalysis_data: xr.Dataset,
                   parallel: bool=False):

    ''' 
    Get the anomaly of a field for every timestamp of a TC and concatenate the results along time.

    Arguments:
    - model_name, experiment_name, year_range: passed through to `get_TC_anomaly_timestamp`.
    - field_name: name of the field in `storm_reanalysis_data` to get the anomaly for.
    - storm_reanalysis_data: storm data with a `time` coordinate, `center_lon`, `center_lat`, `max_wind`, `min_slp`, and a `storm_id` attribute.
    - parallel: process timestamps in a multiprocessing pool instead of a serial loop.
    Returns the time-sorted anomaly with storm track metadata attached. 
    Also prints the sum of the area-integrated anomalies over all timestamps.
    Raises a ValueError if the storm has no timestamps.
    '''

    partial_TC_anomaly_timestamp = functools.partial(get_TC_anomaly_timestamp,
                                                     model_name,
                                                     experiment_name,
                                                     year_range,
                                                     storm_reanalysis_data,
                                                     field_name)

    sampling_timestamps = storm_reanalysis_data.time.values
    if len(sampling_timestamps) == 0:
        raise ValueError('Storm dataset has no timestamps to get anomalies for.')

    # If chosen, run in parallel using the partial function
    if parallel:
        # The context manager tears the pool down on exit; `map` blocks until all results are in
        with Pool() as pool:
            TC_anomaly_outputs = pool.map(partial_TC_anomaly_timestamp, sampling_timestamps)
    # Else, run serial. Serial is usually better for troubleshooting and debugging.
    else:
        TC_anomaly_outputs = [partial_TC_anomaly_timestamp(sampling_timestamp) 
                              for sampling_timestamp in sampling_timestamps]
    # Split the (anomaly field, area-integrated anomaly) pairs into two sequences
    TC_anomaly_timestamps, TC_anomaly_timestamps_integrated = zip(*TC_anomaly_outputs)
    
    TC_anomaly_dataset = xr.concat(TC_anomaly_timestamps, dim='time').sortby('time')
    
    TC_anomaly_dataset['center_lon'] = storm_reanalysis_data['center_lon']
    TC_anomaly_dataset['center_lat'] = storm_reanalysis_data['center_lat']
    TC_anomaly_dataset['max_wind'] = storm_reanalysis_data['max_wind']
    TC_anomaly_dataset['min_slp'] = storm_reanalysis_data['min_slp']
    TC_anomaly_dataset.attrs['storm_id'] = storm_reanalysis_data.attrs['storm_id']

    # NOTE(review): this is a plain sum over timestamps (no multiplication by a time step), labeled with the
    # units of the field itself - confirm that this is the intended diagnostic.
    time_integrated_anomaly = np.sum(np.array(TC_anomaly_timestamps_integrated))
    
    print('------------------------------------------------------------------------------------')
    # A missing `units` attribute should not discard the computed anomalies
    units = storm_reanalysis_data[field_name].attrs.get('units', '')
    print(f'Time-integrated TC anomaly over TC lifetime: {time_integrated_anomaly:.2e} {units}')
    
    return TC_anomaly_dataset

In [33]:
def animator(model_name: str,
             dataset: xr.Dataset, 
             field_name: str,
             anomaly: bool=True,
             extrema: tuple[int|float, int|float] | None=None):

    ''' Visualization. '''
    
    # Perform animation-specific commands for matplotlib's backend
    from matplotlib import animation
    plt.rcParams['animation.embed_limit'] = 2**28
    plt.rcParams["animation.html"] = "jshtml"
    plt.ioff()
    plt.cla()
    
    # Get colormap and normalization for the dataset entirety
    norm, cmap = visualization.norm_cmap(dataset, field=field_name, white_adjust=anomaly, extrema=extrema)

    # Initialize the figure
    projection, reference_projection = ccrs.PlateCarree(central_longitude=180), ccrs.PlateCarree()
    fig, ax = plt.subplots(figsize=(5, 4), subplot_kw={'projection': projection})
    
    def title_generator(model_name: str,
                        storm_sample: xr.Dataset, 
                        field_name: str, 
                        sampling_timestamp: pd.Timestamp):
        ''' Method to generate a string for titling an axis. '''
        maximum_wind = storm_sample['max_wind']
        minimum_pressure = storm_sample['min_slp']
        long_name, units = visualization.field_properties(field_name)
        ax.set_title('') # xArray bug - this blanks out the title
        ax.set_title(f'{model_name}, storm ID: {storm_sample.attrs['storm_id']}\nMax. wind: {maximum_wind:.2f} m/s; min. SLP: {minimum_pressure:.2f} hPa\n{long_name.capitalize()} [{units}]', loc='left', ha='left', fontsize=10)
    
    def init_func():
        ax.clear()
        
        gridlines = ax.gridlines(ls='--', alpha=0.5)
        gridlines.bottom_labels = True
        ax.coastlines()
        
        cax = ax.inset_axes([1.03, 0, 0.02, 1])
        colorbar = fig.colorbar(matplotlib.cm.ScalarMappable(norm, cmap), cax=cax)
    
    def animate(frame):
        ax.clear()
    
        ''' Begin. '''

        snapshot = dataset.isel(time=frame)
        ax.pcolormesh(snapshot.grid_xt,
                      snapshot.grid_yt,
                      snapshot,
                      norm=norm,
                      cmap=cmap,
                      transform=ccrs.PlateCarree())
        
        # Plot storm center
        ax.scatter(snapshot['center_lon'], snapshot['center_lat'], 
                   marker='x', c='k', s=50, zorder=20, transform=ccrs.PlateCarree())
        
        ax.set_aspect('equal')
        ax.coastlines()

        ''' End. '''
    
        cax = ax.inset_axes([1.03, 0, 0.02, 1])
        cax.clear()
        colorbar = fig.colorbar(matplotlib.cm.ScalarMappable(norm, cmap), cax=cax)
    
        gridlines = ax.gridlines(ls='--', alpha=0.5)
        gridlines.bottom_labels = True
        ax.coastlines()
        
        title_generator(model_name, snapshot, field_name, snapshot.time.values)
        fig.tight_layout()
    
    number_of_frames = len(dataset.time.values) if len(dataset.time.values) < 100 else 12
    fig.tight_layout()
    
    anim = animation.FuncAnimation(fig, animate, frames=number_of_frames, init_func=init_func, interval=100, blit=False)
    %matplotlib inline
    return anim

In [49]:
def _filename_tag_value(filename: str, 
                        tag: str) -> str:
    ''' Return the text between `<tag>-` and the next period in a TC filename. '''
    # NOTE(review): a value that itself contains a period would be cut short here - confirm against naming conventions.
    return filename.split(f'{tag}-')[-1].split('.')[0]

def save_animation(anim: 'matplotlib.animation.FuncAnimation',
                   pathname: str,
                   field_name: str,
                   anomaly: bool=True):

    ''' 
    Method to save an animation object for a given TC showing a given field. 

    The output filename is built from the tags in the storm filename, which is expected to look like
    `TC.model-<model>.experiment-<experiment>.storm_ID-<ID>.max_wind-<value>.min_slp-<value>.basin-<basin>.nc`.
    The animation is written as an .mp4 using the ffmpeg writer.
    '''
    
    # The annotation above is a string so that defining this function does not need `matplotlib.animation` to be loaded yet
    import matplotlib.animation as animation

    filename = pathname.split('/')[-1]
    # Storm filenames tag these fields as `model-` and `experiment-`
    model_name = _filename_tag_value(filename, 'model')
    experiment_name = _filename_tag_value(filename, 'experiment')
    animation_storm_ID = _filename_tag_value(filename, 'storm_ID')
    animation_max_wind = _filename_tag_value(filename, 'max_wind')
    animation_min_slp = _filename_tag_value(filename, 'min_slp')
    animation_basin = _filename_tag_value(filename, 'basin')
    anomaly_string = 'anomaly' if anomaly else 'raw'
    animation_filename = f'animation.TC.model-{model_name}.experiment-{experiment_name}.storm_ID-{animation_storm_ID}.max_wind-{animation_max_wind}.min_slp-{animation_min_slp}.basin-{animation_basin}.field_name-{field_name}-{anomaly_string}.mp4'
    animation_dirname = '/projects/GEOCLIM/gr7610/figs/TC-energy_budget/animations'
    # Make sure the output directory exists before the writer tries to open a file in it
    os.makedirs(animation_dirname, exist_ok=True)

    Writer = animation.writers['ffmpeg']
    writer = Writer(fps=12, metadata=dict(artist='Me'), bitrate=1800)
    anim.save(os.path.join(animation_dirname, animation_filename), writer=writer)
#### Access data

In [45]:
# Selection criteria for the storm to analyze
model_name = 'ERA5'
experiment_name = 'reanalysis'
year_range = (2001, 2005)

intensity_parameter = 'min_slp'
intensity_range = (0, 960)

# Seed for the random storm pick: `None` draws from system entropy, an integer makes the pick repeatable
random_seed = None

dirname = '/tigress/GEOCLIM/gr7610/analysis/tc_storage/individual_TCs'
# Sorted so that a seeded pick does not depend on the arbitrary ordering of `os.listdir`
# Both ranges are inclusive of their minimum and exclusive of their maximum
filenames = sorted(filename for filename in os.listdir(dirname) if
                   model_name in filename and
                   experiment_name in filename and
                   min(year_range) <= get_filename_year(filename) < max(year_range) and
                   min(intensity_range) <= get_filename_intensity(filename, intensity_parameter) < max(intensity_range))
if not filenames:
    raise FileNotFoundError(f'No storm files in {dirname} match model {model_name}, experiment {experiment_name}, years {year_range}, and {intensity_parameter} range {intensity_range}.')
pathname = os.path.join(dirname, random.Random(random_seed).choice(filenames))

print(f'Pulling data from filename {pathname}.')
storm_reanalysis_data = xr.open_dataset(pathname, use_cftime=True)
storm_reanalysis_data = utilities.field_correction(model_name, storm_reanalysis_data)

Pulling data from filename /tigress/GEOCLIM/gr7610/analysis/tc_storage/individual_TCs/TC.model-ERA5.experiment-reanalysis.storm_ID-2004-206N20151.max_wind-44.min_slp-935.basin-WP.nc.


#### Pull anomaly for the given TC

In [57]:
# Field to get the anomaly for
field_name = 'olr'

# Get the anomaly over the storm lifetime, with timestamps processed in a multiprocessing pool
TC_anomaly_dataset = get_TC_anomaly(model_name=model_name,
                                    experiment_name=experiment_name,
                                    field_name=field_name,
                                    year_range=year_range,
                                    storm_reanalysis_data=storm_reanalysis_data,
                                    parallel=True)

------------------------------------------------------------------------------------
Time-integrated TC anomaly over TC lifetime: -3.75e+14 W m$^{-2}$


#### Show the animation

In [None]:
# Build the animation of the anomaly field; the bare name on the last line makes the notebook render it
anim = animator(model_name, TC_anomaly_dataset, field_name, anomaly=True)
anim

In [None]:
# Write the animation to disk, named after the tags in the storm filename
save_animation(anim=anim, pathname=pathname, field_name=field_name)