In [None]:
import xarray as xr
import numpy as np

import glob
import psutil
import threading
import time
import os
import gc

from scipy import ndimage

import dask
from dask.diagnostics import ProgressBar

import copernicusmarine
from copernicusmarine import get

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import cartopy.crs as ccrs
import cartopy.feature as cfeature

Load the raw (observed), percentile-threshold, and climatological-mean data for a given location. The script is functional, but still a work in progress.

In [None]:
# Before running, replace every instance of YOUR_DIRECTORY below with the path to your data directory.

In [None]:
## LOAD THE CLIMATOLOGICAL MEANS, PERCENTILE THRESHOLDS, AND RAW DATA FOR A GIVEN LOCATION!

def add_ds_or_correct(path_arg):
    """Open one saved zarr store and return its ``thetao`` variable.

    Some early datasets were saved with a ``doy`` coordinate instead of
    ``normalized_doy``; those are renamed so every store can later be
    concatenated along ``normalized_doy``.

    Parameters
    ----------
    path_arg : str
        Path to a ``.zarr`` store containing a ``thetao`` variable.

    Returns
    -------
    xarray.DataArray
        The ``thetao`` array, with ``normalized_doy`` as a dimension when the
        store originally used ``doy``.
    """
    ds = xr.open_zarr(path_arg).thetao

    # Correct for any datasets saved with doy instead of normalized_doy
    if 'doy' in ds.coords:
        ds = ds.rename({'doy': 'normalized_doy'})
        # A scalar doy coordinate must be promoted to a length-1 dimension for
        # concatenation. If doy was already a dimension, expand_dims would raise
        # because the dimension exists, so only promote when it is missing.
        if 'normalized_doy' not in ds.dims:
            ds = ds.expand_dims('normalized_doy')

    return ds

def grab_data(directory, folder_filename_arg, is_raw=False):
    """Load and concatenate every matching zarr store found in ``directory``.

    Parameters
    ----------
    directory : str
        Folder to search.
    folder_filename_arg : str
        Filename prefix of threshold/climatology stores
        (``{folder_filename_arg}_*.zarr``). Ignored when ``is_raw`` is True.
    is_raw : bool, default False
        If True, load the raw ``daily_data*.zarr`` stores, concatenate them along
        ``time`` and attach a ``normalized_doy`` coordinate along ``time``.
        Otherwise concatenate along ``normalized_doy``.

    Returns
    -------
    xarray.DataArray
        Concatenated ``thetao`` data, sorted along the concatenation dimension.

    Raises
    ------
    FileNotFoundError
        If no store matches the search pattern. Without this check, a wrong path
        (or an un-replaced ``YOUR_DIRECTORY`` placeholder) surfaces as an opaque
        ``xr.concat`` error about an empty list.
    """
    # Sorted so stores are concatenated in filename order
    if is_raw:
        pattern = f'{directory}/daily_data*.zarr'
    else:
        pattern = f'{directory}/{folder_filename_arg}_*.zarr'
    paths = sorted(glob.glob(pattern))

    if not paths:
        raise FileNotFoundError(f"No zarr stores match '{pattern}'. Check the directory path.")

    # Load the datasets, correcting any old doy-based stores
    datasets = [add_ds_or_correct(path) for path in paths]

    # Concatenate and sort along the appropriate dimension
    if is_raw:
        full_ds = xr.concat(datasets, dim="time").sortby("time")
        full_ds = full_ds.assign_coords(normalized_doy=('time', normalize_dayofyear(full_ds.time).data))
    else:
        full_ds = xr.concat(datasets, dim="normalized_doy").sortby("normalized_doy")
    return full_ds
    
def get_max_and_mins(ds):
    """Return the geographic bounds of ``ds``.

    Returns
    -------
    tuple
        ``(min_lat, max_lat, min_lon, max_lon)``, each converted with ``.item()``
        to a plain Python scalar.
    """
    lats, lons = ds.latitude, ds.longitude
    return (
        lats.min().item(),
        lats.max().item(),
        lons.min().item(),
        lons.max().item(),
    )
            
def normalize_dayofyear(time_coord):
    """Map timestamps to a leap-year-consistent day of year.

    In non-leap years every day from March 1 onward (day-of-year 60 and later)
    is shifted up by one. March 1 is therefore always 61 and 60 only ever means
    February 29, so the same calendar day shares one index across years.

    Parameters
    ----------
    time_coord : xarray.DataArray
        Datetime coordinate exposing the ``.dt`` accessor.

    Returns
    -------
    xarray.DataArray
        The normalized day of year for each timestamp.
    """
    day_of_year = time_coord.dt.dayofyear
    from_march_in_common_year = (day_of_year >= 60) & ~time_coord.dt.is_leap_year
    return xr.where(from_march_in_common_year, day_of_year + 1, day_of_year)

def extra_filter(ds, region_name=None):
    """Apply region-specific trimming to ``ds``.

    For the "Atlantic" region, data whose longitude starts at exactly -101 is
    trimmed to longitudes -101..-14.001. Data that extends below latitude 0 is
    trimmed to latitudes from -0.75 up to its maximum. Other regions are
    returned unchanged.

    Parameters
    ----------
    ds : xarray.DataArray or xarray.Dataset
        Data with ``latitude`` and ``longitude`` coordinates. The trims use
        label ``slice``s, so the coordinates are assumed ascending (TODO confirm
        for all stores).
    region_name : str, optional
        Region to filter for. Defaults to the notebook-level ``folder_filename``
        (the loading cell's loop variable), which is what this function always
        read before the parameter existed.

    Returns
    -------
    The (possibly) trimmed ``ds``.
    """
    if region_name is None:
        region_name = folder_filename

    if region_name == "Atlantic":
        # numpy's vectorised min/max instead of the builtins, which iterate the
        # coordinate element by element in Python
        if ds.longitude.values.min() == -101:
            ds = ds.sel(longitude=slice(-101, -14.001))
        if ds.latitude.values.min() < 0:
            max_val = ds.latitude.values.max().item()
            ds = ds.sel(latitude=slice(-0.75, max_val))
    return ds
    
def show_map(ds_input, title, date=None, chosen_depth=0, chosen_doy=1, is_raw=False, extent=None):
    """Draw a quick-look Mercator map of one time slice and one depth of ``ds_input``.

    Parameters
    ----------
    ds_input : xarray.DataArray
        Data with ``latitude``, ``longitude`` and ``depth`` coordinates, plus
        ``time`` (raw data) or ``normalized_doy`` (thresholds / climatology).
    title : str
        Used for the colorbar label and the plot title.
    date : str, optional
        Date to plot (nearest match). Required when ``is_raw`` is True.
    chosen_depth : float, default 0
        Depth to plot (nearest match).
    chosen_doy : int, default 1
        Normalized day of year to plot (nearest match) when ``is_raw`` is False.
    is_raw : bool, default False
        Select along ``time`` (True) or ``normalized_doy`` (False).
    extent : sequence of 4 floats, optional
        ``[lon_min, lon_max, lat_min, lat_max]`` in PlateCarree degrees. Defaults
        to the plotted data's bounds, with latitude clipped to cartopy's default
        Mercator limits (-80 to 84).

    Raises
    ------
    ValueError
        If ``is_raw`` is True and no ``date`` is given.
    """
    if is_raw:
        if date is None:
            raise ValueError("show_map: a date is required when is_raw=True.")
        ds = ds_input.sel(time=date, depth=chosen_depth, method='nearest')
        when = date
    else:
        ds = ds_input.sel(normalized_doy=chosen_doy, depth=chosen_depth, method='nearest')
        # doy-indexed data has no calendar date, so label it with the doy instead
        when = f"(doy {chosen_doy})"

    if extent is None:
        # The old hard-coded [0, 360, -30, 90] reached 90°N, which a Mercator
        # projection cannot represent, and ignored where the data actually lie.
        extent = [
            float(ds.longitude.min()), float(ds.longitude.max()),
            max(float(ds.latitude.min()), -80.0), min(float(ds.latitude.max()), 84.0),
        ]

    fig, ax = plt.subplots(figsize=(10, 6), subplot_kw={'projection': ccrs.Mercator()})
    im = ax.pcolormesh(ds.longitude, ds.latitude, ds,
                       transform=ccrs.PlateCarree(),
                       cmap='RdYlBu_r')
    ax.set_extent(extent, crs=ccrs.PlateCarree())
    ax.coastlines()
    ax.gridlines(draw_labels=True)
    cbar = plt.colorbar(im, ax=ax, shrink=0.7)
    cbar.set_label(f'{title}', rotation=270, labelpad=15)
    ax.set_title(f'{title} {when}', fontsize=14)
    plt.show()
    # Free the figure; this may be called once per dataset in a loop
    plt.close(fig)
    
    
## PICK THE LOCATION HERE!
# Regions (and their custom subregions) to load; uncomment entries to add more.
region_dict_list = [
    {"Atlantic": [
        "Center",
        "Right",
        "Top",
    ]},
    #  {"Pacific": ["Center",
    #               #"Left"
    #              ]},
    #{"Mid": ["Mid",
    #    #"All",
    #]}
]

# Chunk sizes per dimension. Each should divide that dimension's length evenly
# (e.g. 366 normalized doys / 61 = 6 chunks); depth=-1 keeps all depths in one chunk.
best_chunk_configs = {
    "Atlantic": {'depth': -1, 'normalized_doy':61, 'latitude': 109, 'longitude': 242},
}

best_raw_chunk_configs = {
    "Atlantic": {'depth': -1, 'time': 24, 'latitude': 109, 'longitude': 242},
}

percentile = 90

# Whether to draw quick-look maps. Named show_maps (not show_map) so it does not
# overwrite the show_map() plotting function defined above.
show_maps = True


def _combine_thetao(pieces):
    """Merge per-subregion ``thetao`` arrays into a single DataArray.

    ``xr.combine_by_coords`` returns a Dataset when given named DataArrays, so
    ``thetao`` is pulled back out of it. A lone piece is already the ``thetao``
    DataArray returned by ``grab_data``; calling ``.thetao`` on it raised
    AttributeError whenever a region had a single subregion.
    """
    if len(pieces) > 1:
        combined = xr.combine_by_coords(pieces, compat='no_conflicts')
    else:
        combined = pieces[0]
    return combined.thetao if isinstance(combined, xr.Dataset) else combined


for folder in region_dict_list:
    print("Now creating combined climatological datasets for a particular region, which have the same dimensions...\n")

    for folder_filename, sub_folders in folder.items():
        print(f"Current main region: {folder_filename}\n")

        # Without subregions there is nothing to combine (and id_path would be undefined below)
        if not sub_folders:
            raise ValueError(f"Region '{folder_filename}' has no subregions listed in region_dict_list.")

        obs_datasets = []
        thresh_datasets = []

        # Load every subregion and trim observed/threshold data to their common footprint
        for sub_folder_filename in sub_folders:
            print(f"Current custom subregion: {sub_folder_filename}\n")

            # Your filepath here; this is my setup.
            id_path = f"{folder_filename}_{sub_folder_filename}"
            thresh_data_directory = f'YOUR_DIRECTORY/Thresh{percentile}th/{folder_filename}/{id_path}'

            # Grab the percentile threshold data
            thresh_ds = grab_data(thresh_data_directory, folder_filename, False)

            # Grab the raw ("observed") data
            obs_data_directory = f'YOUR_DIRECTORY/Data/{folder_filename}/{id_path}'
            obs_ds = grab_data(obs_data_directory, folder_filename, True)

            # Trim the observed data to the threshold bounds...
            thresh_min_lat, thresh_max_lat, thresh_min_lon, thresh_max_lon = get_max_and_mins(thresh_ds)
            obs_ds = obs_ds.sel(latitude=slice(thresh_min_lat, thresh_max_lat),
                                longitude=slice(thresh_min_lon, thresh_max_lon))

            # ...then the threshold data to the trimmed observed bounds
            obs_min_lat, obs_max_lat, obs_min_lon, obs_max_lon = get_max_and_mins(obs_ds)
            thresh_ds = thresh_ds.sel(latitude=slice(obs_min_lat, obs_max_lat),
                                      longitude=slice(obs_min_lon, obs_max_lon))

            thresh_ds = extra_filter(thresh_ds)
            obs_ds = extra_filter(obs_ds)

            obs_datasets.append(obs_ds)
            thresh_datasets.append(thresh_ds)

        print("--------------------------------------------------------------------------------------------------\n\n")

        ## Combining stored datasets for each folder...
        print("Folder:", folder_filename, '\n')

        optimal_chunks = best_chunk_configs[folder_filename]
        raw_optimal_chunks = best_raw_chunk_configs[folder_filename]

        obs_combined_ds = _combine_thetao(obs_datasets)
        thresh_combined_ds = _combine_thetao(thresh_datasets)

        # Bounds of the combined observed data, used to slice the climatology
        obs_comb_min_lat, obs_comb_max_lat, obs_comb_min_lon, obs_comb_max_lon = get_max_and_mins(obs_combined_ds)

        ## Clim dataset
        # The "Mid" subregions store their climatology under the subregion id; id_path
        # is the id of the last subregion processed above.
        main_folder = id_path if id_path in ("Mid_Mid", "Mid_All") else folder_filename
        clim_data_directory = f'YOUR_DIRECTORY/Clim/Full/{main_folder}'
        clim_ds = grab_data(clim_data_directory, folder_filename, False)

        # Publish e.g. Atlantic_clim / Atlantic_thresh / Atlantic_obs for later cells
        globals()[f'{folder_filename}_clim'] = clim_ds.sel(latitude=slice(obs_comb_min_lat, obs_comb_max_lat),
                              longitude=slice(obs_comb_min_lon, obs_comb_max_lon)).chunk(optimal_chunks)
        globals()[f'{folder_filename}_thresh'] = thresh_combined_ds.chunk(optimal_chunks)
        globals()[f'{folder_filename}_obs'] = obs_combined_ds.chunk(raw_optimal_chunks)

        print("Final observed:\n", globals()[f'{folder_filename}_obs'], '\n')
        print("Final thresh:\n", globals()[f'{folder_filename}_thresh'], '\n')
        print("Final means:\n", globals()[f'{folder_filename}_clim'], '\n\n')

        # CURRENTLY UNAVAILABLE
        # NOTE: these calls predate show_map()'s current argument list; check its
        # signature (date / chosen_depth / chosen_doy / is_raw) before re-enabling.
        #if show_map:
            # Showing a quick map of the dataset (to make sure everything came out right!)
         #   show_map(ds_input=globals()[f'{folder_filename}_obs'], title="Observed (deg C)", date="2003-02-28", is_raw=False)
          #  show_map(globals()[f'{folder_filename}_thresh'], "90th Percentile (deg C)")
          #  show_map(globals()[f'{folder_filename}_clim'], "Climatological Mean (deg C)")

        print("------------------------------------------------o-------------------------------------------------\n")

In [None]:
## Thetao depth levels: load the depth coordinate used for layer-thickness calculations

# Load a dataset whose depth coordinate extends at least one level PAST your deepest
# target depth: the severity calculation bounds the last layer using the next depth level.
depth_filepath = "YOUR_DIRECTORY/Data/Misc/example_file_depths_up_to_NUMBERm.zarr"
depth_final = 319 # depth cutoff (m); set it so one extra depth level past your deepest target is kept


# Next, gather all depth levels no deeper than the cutoff
depths_ds = xr.open_zarr(depth_filepath)

# drop=True removes levels deeper than depth_final instead of masking them with NaN
all_depth_levels = depths_ds.depth.where(depths_ds.depth <= depth_final, drop=True) # includes the extra level past the target
print("All depth levels: ", '\n', all_depth_levels, '\n')

## From this output, choose the target depths you want to reach (each must have a deeper level after it)

In [None]:
## Helper functions and the column-averaged severity calculation
def get_max_and_min_years(ds):
    """Return ``(first_year, last_year)`` from the first and last ``time`` entries of ``ds``.

    Despite the name, the smaller (earlier) year comes first. Only the endpoints
    are inspected, so ``time`` is assumed to be sorted ascending.
    """
    first_stamp, last_stamp = ds.time[0], ds.time[-1]
    return first_stamp.dt.year.item(), last_stamp.dt.year.item()


def check_valid_type(data, data_type, data_name, correction_message):
    """Raise ``ValueError`` unless ``data`` is an instance of ``data_type``.

    Parameters
    ----------
    data : object
        Value to validate.
    data_type : type or tuple of types
        Accepted type(s), exactly as for ``isinstance``.
    data_name : str
        Human-readable name of the value, used in the error message.
    correction_message : str
        Suggested alternative, appended to the error message.

    Raises
    ------
    ValueError
        If ``data`` is not an instance of ``data_type``. A ``bool`` is also
        rejected where ``int`` is accepted (unless ``bool`` is listed explicitly),
        because ``isinstance(True, int)`` is True and a flag passed as a year or
        coordinate is almost certainly a mistake.
    """
    accepted = data_type if isinstance(data_type, tuple) else (data_type,)
    is_stray_bool = isinstance(data, bool) and int in accepted and bool not in accepted
    if is_stray_bool or not isinstance(data, accepted):
        # Join names so a tuple of types no longer crashes on data_type.__name__
        type_label = " or ".join(t.__name__ for t in accepted)
        raise ValueError(f"Invalid {data_name} provided. Please provide a proper {type_label} {data_name}, or {correction_message}.")


def event_gap_filling(events_series: np.ndarray, minDuration: int = 5, maxGap: int = 2) -> np.ndarray:
    '''
    Gap-filling function for marine heatwave (MHW) events following the Hobday et al. (2016) definition.

    Key rules:
    1. Only events of 5+ (or minDuration+) days duration are considered valid MHWs
    2. Gaps of 2 (or maxGap) days or less between any valid events should merge them
    3. When merging occurs, the entire merged period becomes one event
    4. The resulting event must still meet the 5-day minimum duration

    Examples from Hobday et al. (2016):
    - [5hot, 2cool, 6hot] → 13-day event (5 + 2 + 6 = 13)
    - [5hot, 1cool, 2hot] → 5-day event (only first 5 days qualify)
    - [2hot, 1cool, 5hot] → 5-day event (only last 5 days qualify)
    - [5hot, 4cool, 6hot] → two separate events (5-day + 6-day)

    Parameters
    ----------
    events_series : array-like, 1-D
        Per-day exceedance flags. Non-boolean input is converted to bool
        (a diagnostic line is printed). NaN counts as "not hot".
    minDuration : int, default 5
        Minimum run of consecutive hot days for a run to count as an event.
    maxGap : int, default 2
        Runs separated by at most this many non-hot days are merged.

    Returns
    -------
    np.ndarray of int
        Same length as the input. 0 marks non-event days and 1, 2, ... label each
        (merged) event in time order.

    Raises
    ------
    ValueError
        If the input is not 1-D. This replaces an ``assert``, which is skipped
        when Python runs with ``-O``.
    '''
    # Accept lists/other array-likes as well as ndarrays
    events_series = np.asarray(events_series)
    if events_series.ndim != 1:
        raise ValueError(f"Expected 1D series, got {events_series.ndim}D")

    # We convert the dataset to boolean if it is not boolean
    if events_series.dtype != bool:
        print("Dataset conversion occurred!", '\n')
        if np.issubdtype(events_series.dtype, np.floating):
            # astype(bool) would turn NaN (e.g. masked land points) into True, i.e. a hot day
            events_series = np.nan_to_num(events_series, nan=0.0) != 0
        else:
            events_series = events_series.astype(bool)

    out = np.zeros(events_series.shape, dtype=int)

    # If the input series is full of False values, we return an all 0's array
    if not events_series.any():
        return out

    ## Find every run of consecutive True values: padding with 0 on both sides makes
    ## diff == +1 at each run start and -1 just after each run end.
    edges = np.diff(np.concatenate(([0], events_series.astype(np.int8), [0])))
    run_starts = np.flatnonzero(edges == 1)
    run_ends = np.flatnonzero(edges == -1) - 1  # inclusive end index

    # Keep only runs that meet the minimum duration
    long_enough = (run_ends - run_starts + 1) >= minDuration
    run_starts, run_ends = run_starts[long_enough], run_ends[long_enough]
    if run_starts.size == 0:
        return out

    ## Group valid runs: a gap longer than maxGap starts a new event
    breaks = (run_starts[1:] - run_ends[:-1] - 1) > maxGap
    group_starts = run_starts[np.concatenate(([True], breaks))]
    group_ends = run_ends[np.concatenate((breaks, [True]))]

    # Label each merged event (inclusive of any bridged gap days) with a 1-based id
    for event_id, (start, end) in enumerate(zip(group_starts, group_ends), 1):
        out[start:end + 1] = event_id

    return out
        
    
def create_continuous_time_coordinate(ds):
    """Replace the ``normalized_doy`` dimension with a 1..N ``time_continuous`` counter.

    The padded series built during MHW detection concatenate end-of-year, current
    and start-of-year doys, so ``normalized_doy`` is not monotonic there. A plain
    step counter gives a strictly increasing axis. The result is rechunked into a
    single chunk along the new dimension.

    Parameters
    ----------
    ds : xarray.DataArray or xarray.Dataset
        Data with a ``normalized_doy`` dimension.

    Returns
    -------
    Same type as ``ds``, indexed by ``time_continuous`` (values 1..N), with
    ``normalized_doy`` kept as a non-dimension coordinate.
    """
    step_index = np.arange(len(ds.normalized_doy)) + 1
    relabeled = (
        ds.assign_coords(time_continuous=('normalized_doy', step_index))
          .swap_dims({'normalized_doy': 'time_continuous'})
    )
    return relabeled.chunk({'time_continuous': -1})


def run_mhw_test(initial_mask_ds, final_mask_ds, lat_val=None, lon_val=None, depth_val=None):
    """Print raw exceedances next to gap-filled event labels at one grid point.

    Both datasets are sampled at the point nearest to (``lat_val``, ``lon_val``,
    ``depth_val``). The two series are then printed in full and in windows of 10
    values so they can be compared by eye.

    Raises
    ------
    ValueError
        If any of ``lat_val``, ``lon_val`` or ``depth_val`` is None.
    """
    if not all(value is not None for value in (lat_val, lon_val, depth_val)):
        raise ValueError("Missing a key argument in the run_mhw_test function. Please ensure a valid integer value is provided for each value argument!")

    print("--------------------------------------------------   TEST START   ---------------------------------------------------\n")
    print(f"RUNNING TEST FOR NEAREST LATITUDE: {lat_val}, LONGITUDE: {lon_val}, AND DEPTH: {depth_val}.\n")

    point = dict(latitude=lat_val, longitude=lon_val, depth=depth_val)

    # Report the grid point actually chosen by the nearest-neighbour lookup
    exceed_point = initial_mask_ds.sel(**point, method='nearest')
    print(
        "ACTUAL VALUES DETECTED:\n"
        f"LATITUDE: {exceed_point.latitude.values.item()}, "
        f"LONGITUDE: {exceed_point.longitude.values.item()}, "
        f"DEPTH: {exceed_point.depth.values.item()}\n"
    )

    exceed_values = exceed_point.values
    print("Exceed values:\n", exceed_values, '\n\n')

    event_values = final_mask_ds.sel(**point, method='nearest').values
    print("Event gap-filled values:\n", event_values, '\n\n')

    print("----------o-----------0----------o-------------\n")

    window = 10
    n_points = min(len(exceed_values), len(event_values))
    for start in range(0, n_points, window):
        stop = min(start + window, n_points)
        print(f"Values {start} to {stop - 1}:")
        print("Exceed values:", exceed_values[start:stop])
        print("Event gap-filled values:", event_values[start:stop])
        print("-" * 60)
        print(" ")
    print("--------------------------------------------------   END OF TEST   ---------------------------------------------------\n")


stop_monitoring = True # We begin by NOT showing any memory usage
def monitor_memory(interval_minutes=5, log_file=None):
    """Print (and optionally log) this process's memory use until ``stop_monitoring`` is True.

    Meant to run in a daemon thread. It polls the module-level ``stop_monitoring``
    flag, which callers set to False before starting the thread and back to True
    to stop it.

    Parameters
    ----------
    interval_minutes : float, default 5
        Minutes between reports.
    log_file : str, optional
        If given, each report's timestamp and RSS (in GB) is appended to this file.
    """
    interval = interval_minutes * 60
    poll_seconds = 1.0  # how often the stop flag is re-checked while waiting

    while not stop_monitoring:
        mem = psutil.Process(os.getpid()).memory_info().rss / (1024**3)  # in GB
        print(f" | Memory usage: {mem:.2f} GB | Memory: {psutil.virtual_memory().percent}% used | ")

        if log_file:
            with open(log_file, 'a') as f:
                f.write(f"{time.strftime('%Y-%m-%d %H:%M:%S')}: {mem:.2f} GB\n")

        # Sleep in short slices so setting stop_monitoring = True ends this thread
        # within about a second. A single long sleep kept a "stopped" monitor alive
        # (and printing) for up to a full interval, overlapping any newly started one.
        deadline = time.monotonic() + interval
        while not stop_monitoring:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            time.sleep(min(poll_seconds, remaining))
    
    
# -----------------------------------------------------------------------------------------------------------------------------
def calculate_severity(obs_data, thresh_data, clim_data, 
                       baseline_name, folder_name, sub_folder_name,
                       best_chunks, best_chunks_no_depth, 
                       depths_list, depth_thicknesses,
                       smoothWidth=31, minDaysMhwDuration=5, maxGapDaysInMhw=2, 
                       minutes_per_mem_update=5,
                       show_debug=True,
                       custom_years=False, start_yr=None, end_yr=None,
                       mhw_test=False, mhw_test_lat=None, mhw_test_lon=None, mhw_test_depth=None):

    # Quick check for inputs regarding a marine heatwave test a specific coordinate
    if mhw_test:
        check_valid_type(mhw_test_lat, int, "latitude value", "set mhw_test to False")
        check_valid_type(mhw_test_lon, int, "longitude value", "set mhw_test to False")
        check_valid_type(mhw_test_depth, int, "depth value", "set mhw_test to False")
    
    # Gather the min and max years in the observed data
    min_obs_yr, max_obs_yr = get_max_and_min_years(obs_data)

    # Setting up custom years, or going with the default (min to max years in the observed data)
    if custom_years:
        check_valid_type(start_yr, int, "starting year", "set custom_years to False")
        check_valid_type(end_yr, int, "ending year", "set custom_years to False")
        years = range(start_yr, end_yr + 1)
    else:
        years = range(min_obs_yr, max_obs_yr + 1)
        

    # Beginning of code for severity calculation:
    print("--------------------o-------------------------------o------------------------------o----------------------------")
    print(f'\nBaseline used: {baseline_name}\n') # We only have one baseline
    print(f"The smoothWidth argument was set to: {smoothWidth}.") 
    print(f"This means a rolling window of {smoothWidth} doys centered on each individual doy is used.")
    print(f"For each doy, the rolling window contains the previous {(smoothWidth-1)/2} doys, the center doy, and the next {(smoothWidth-1)/2} doys.\n")

    # Showing the original datasets if show_debug is enabled
    if show_debug: 
        print('Original raw, "observed" data:\n', obs_data, '\n')
        print("Original percentile threshold data:\n", thresh_data, '\n')
        print("Original climatological means data:\n", clim_data, '\n')
        print("-----------------------------------------------------------------------------------------------------------\n")
    
    
    ## Calculating layer thicknesses, and taking the column-average of the severity dataset for marine heatwaves and entire columns
    for target_depth in depths_list:
        # Getting the lower bound from the target list
        max_wanted_depth = depth_thicknesses.sel(depth=target_depth, method='nearest').item()
        print("Inputted Lower-Bound Depth: ", max_wanted_depth)
            
        # Get the positional index of the bound
        depth_coord = depth_thicknesses.get_index('depth').values
        max_depth_index = np.where(depth_coord == max_wanted_depth)[0][0]

        if max_depth_index + 1 < len(depth_coord):
            next_depth = depth_coord[max_depth_index + 1]
            next_value = depth_thicknesses.sel(depth=next_depth).item()
            print("Following Lower-Bound Depth: ", next_value)
        else:
            raise ValueError("Error: Depth bound exceeded! Please provide a target depth list with lower depth values!")

        # Gathering only the values we intend on using (depth values up to 50 meters)
        valid_idx = np.where(depth_thicknesses <= max_wanted_depth)[0]
        valid_depths = depth_thicknesses[valid_idx]
        depth_values = valid_depths.values

        # Creating layer bounds, using depth temperatures as midpoints
        bounds = np.zeros(len(depth_values) + 1)
        bounds[0] = 0 # Surface boundary

        # Midpoints between depths for internal bounds
        bounds[1:-1] = (depth_values[:-1] + depth_values[1:]) / 2

        # The last bound is the midpoint between the last included depth (in depth_values) and first excluded depth (in all_depth_values)
        if len(valid_idx) < len(depth_thicknesses):
            bounds[-1] = (depth_values[-1] + next_value) / 2
            final_max_depth_val = round(bounds[-1], 1)
            print(f"Final (Interpolated) Lower-Bound Depth Used: {final_max_depth_val} meters\n") 

        if show_debug: 
            print("Depth values: ", '\n', depth_values, '\n')
            print("Bounds: ", '\n', bounds, '\n')    
            print("Length of valid depths: ", len(depth_values))
            print("Length of bounds: ", len(bounds), '(Should be 1 more than valid depths, for thickness calculation)\n')

        # Creating a layer thickness data array
        layer_thickness = xr.DataArray(np.diff(bounds), coords={"depth": valid_depths}, dims="depth")
        if show_debug: 
            print("Final layer thicknesses: ", '\n', layer_thickness, '\n')
            print("--------------------------------------------------------------------------------------", '\n')
            
        # Subsetting the datasets by depth
        obs_data = obs_data.sel(depth=slice(0, max_wanted_depth))
        thresh_data = thresh_data.sel(depth=slice(0, max_wanted_depth))
        clim_data = clim_data.sel(depth=slice(0, max_wanted_depth))
        
        # Showing the depth-subsetted datasets if show_debug is enabled
        if show_debug: 
            print('Depth-Subsetted "observed" data:\n', obs_data, '\n')
            print("Depth-Subsetted percentile threshold data:\n", thresh_data, '\n')
            print("Depth-Subsetted climatological means data:\n", clim_data, '\n')
            print("-----------------------------------------------------------------------------------------------------------\n")
    
       
        ## PADDING AND SMOOTHING
        # Padding with smoothWidth (which is more than enough for the desired smoothing with smoothWidth)
        padded_clim = clim_data.pad(normalized_doy=smoothWidth, mode='wrap')
        if show_debug: print("Padded Climatology:\n", padded_clim, '\n')

        padded_thresh = thresh_data.pad(normalized_doy=smoothWidth, mode='wrap')
        if show_debug: print("Padded Threshold:\n", padded_thresh, '\n\n')

        # Smoothing
        clim_smoothed = padded_clim.rolling(normalized_doy=smoothWidth, center=True, min_periods=smoothWidth).mean() 
        if show_debug: print(f"Padded Climatology Smoothed by {smoothWidth} doys:\n", clim_smoothed, '\n')

        clim_smoothed = clim_smoothed.isel(normalized_doy=slice(smoothWidth, -smoothWidth))
        if show_debug: print("Smoothed Climatology (No Pad):\n", clim_smoothed, '\n\n')

        thresh_smoothed = padded_thresh.rolling(normalized_doy=smoothWidth, center=True, min_periods=smoothWidth).mean()
        if show_debug: print(f"Padded Threshold Smoothed by {smoothWidth} doys: ", thresh_smoothed, '\n')

        thresh_smoothed = thresh_smoothed.isel(normalized_doy=slice(smoothWidth, -smoothWidth))
        if show_debug: print("Threshold Smoothed (No Pad): ", thresh_smoothed, '\n')

        # We also save early and late smoothed doy data for later 
        # MAY NEED TO BE IN THE FOR LOOP IF USING A LARGER smoothWidth (if the smoothWidth includes Feb 29, doy 60)
        thresh_start_days = thresh_smoothed.isel(normalized_doy=slice(0, smoothWidth)).copy()
        thresh_end_days = thresh_smoothed.isel(normalized_doy=slice(-smoothWidth, None)).copy()


        global stop_monitoring
        if show_debug:
            stop_monitoring = True # We don't want to start showing memory use.
            print("-----------------------------------------------------------------------------------------------------------\n")
        else:
            # We start monitoring here so that it only runs once
            stop_monitoring = False # We do want to start showing memory use.
            monitor_thread = threading.Thread(target=monitor_memory, kwargs={'interval_minutes': minutes_per_mem_update})
            monitor_thread.daemon = True
            monitor_thread.start()

        shown_once = False # for debugging purposes


        ## CALCULATING ANOMALIES BY YEAR (observed - climatological means) AND DETECTING MHWS
        for year in years:
            # Time subsetting
            obs_year_data = obs_data.sel(time=f'{year}')
            obs_year_data_norm = obs_year_data.swap_dims({'time':'normalized_doy'})
            if show_debug and not shown_once: print(f"Observed data for {year}:\n", obs_year_data_norm, '\n')

            # Aligning the datasets (returns doys common to both)
            obs_aligned, clim_aligned = xr.align(obs_year_data_norm, clim_smoothed, join="inner")
            clim_aligned = clim_aligned.chunk(best_chunks)

            obs_aligned, thresh_aligned = xr.align(obs_year_data_norm, thresh_smoothed, join="inner")
            thresh_aligned = thresh_aligned.chunk(best_chunks)

            if show_debug and not shown_once: 
                print(f"Thresh aligned: ", thresh_aligned, '\n')
                print(f"Clim aligned: ", clim_aligned, '\n')

            # We also save the time variable separately for later, then drop it from the observed
            obs_aligned_time = obs_aligned.swap_dims({'normalized_doy':'time'}).drop_vars('normalized_doy').time
            if show_debug and not shown_once: print("-----------------------------------------------------------------------------------------------------------\n")

            ### DETECTING MHWS ACROSS THE YEARS
            ## Grabbing previous and next year data
            obs_aligned = obs_aligned.swap_dims({'normalized_doy':'time'})
            if show_debug and not shown_once: print(f"Full Year Data from the Current Year ({year}):\n", obs_aligned, '\n')

            # Important objects for padding
            prev_year_data = None
            next_year_data = None

            # We grab the previous year's data if it is not the first or final year in the dataset
            if year > min_obs_yr:
                prev_year = year - 1
                prev_year_full = obs_data.sel(time=f'{prev_year}')

                # If it exists, we grab the previous year's final period of length smoothWidth (up to 366)
                if len(prev_year_full) > 0:
                    prev_year_data = prev_year_full.isel(time=slice(-smoothWidth, None))
                    if show_debug and not shown_once: print(f"End of the Year Data from the Previous Year ({prev_year}):\n", prev_year_data, '\n')

            # We grab the next year's data if it is not the first or final year in the dataset
            if year < max_obs_yr:
                next_year = year + 1
                next_year_full = obs_data.sel(time=f'{next_year}')

                if len(next_year_full) > 0:
                    next_year_data = next_year_full.isel(time=slice(0, smoothWidth))
                    if show_debug and not shown_once: print(f"Beginning of the Year Data from the Next Year ({next_year}):\n", next_year_data, '\n')

            ## Merging the previous and next year datasets, if they are available
            data_pieces = [] 

            # We first append the previous year data to our array, if it exists
            if prev_year_data is not None: data_pieces.append(prev_year_data)

            # Next, we append the current year data to our array, if it exists
            data_pieces.append(obs_aligned)

            # Lastly, we append the next year data to our array, if it exists
            if next_year_data is not None: data_pieces.append(next_year_data)

            if show_debug and not shown_once: print("Data pieces (prev + full current + next):\n", data_pieces, '\n\n')

            # Merging the datasets
            obs_year_data_extended = xr.concat(data_pieces, dim='time')
            obs_year_data_extended = obs_year_data_extended.sortby('time')
            if show_debug and not shown_once: print("Initial merged observed data:\n", obs_year_data_extended, '\n')

            # We switch back to our desired normalized doy time dimension
            obs_year_data_extended = obs_year_data_extended.swap_dims({'time':'normalized_doy'}).drop_vars('time')
            obs_year_data_extended = obs_year_data_extended.chunk(best_chunks)
            if show_debug and not shown_once: 
                print("Final merged observed data: ", obs_year_data_extended, '\n')
                print("-----------------------------------------------------------------------------------------------------------\n")

            ## Threshold padding
            data_pieces_thresh = []

            # We append previous year data if it exists first
            if prev_year_data is not None:
                data_pieces_thresh.append(thresh_end_days)
                if show_debug and not shown_once: print('"Previous" Year Threshold Data:\n', thresh_end_days, '\n')  # doys with smoothwidth length up to 366

            # We then append current year data
            data_pieces_thresh.append(thresh_aligned)

            # We lastly append next year data if it exists
            if next_year_data is not None:
                data_pieces_thresh.append(thresh_start_days)
                if show_debug and not shown_once: print('"Next" Year Threshold Data:\n', thresh_start_days, '\n') # doy 1 up to the end of smoothwidth doys

            if show_debug and not shown_once: print("Data pieces thresh (prev + full current + next):\n", data_pieces_thresh, '\n\n')

            # Since we cannot sort by time, order matters most here!
            thresh_data_extended = xr.concat(data_pieces_thresh, dim='normalized_doy')
            thresh_data_extended = thresh_data_extended.chunk(best_chunks)
            if show_debug and not shown_once: 
                print("Final Merged Threshold Data:\n", thresh_data_extended, '\n')
                print("-----------------------------------------------------------------------------------------------------------\n")


            ## Creating continuous coordinates (but only if the obs and thresh datasets have matching days of the year)
            if obs_year_data_extended.normalized_doy.equals(thresh_data_extended.normalized_doy):
                if show_debug and not shown_once: 
                    print("Aligned observed and padded thresh datasets' normalized_doys match!\n")
                    print("Creating a new, continous coordinate for each for marine heatwave detection...\n")

                obs_time_aligned = create_continuous_time_coordinate(obs_year_data_extended)
                thresh_time_aligned = create_continuous_time_coordinate(thresh_data_extended)

                if show_debug and not shown_once:
                    print("Continuous Observed: ", '\n', obs_time_aligned, '\n\n',
                          "Continuous Thresh: ", '\n', thresh_time_aligned, '\n')
            else:
                print("ERROR DETECTED! PRINTING RELEVANT OUTPUT:\n")
                print(obs_year_data_extended.normalized_doy, '\n\n', thresh_data_extended.normalized_doy, '\n')
                print(obs_year_data_extended.normalized_doy.values, '\n',
                     thresh_data_extended.normalized_doy.values)

                raise ValueError("Unexpected Error Detected: Coordinate arrays of observed and threshold datasets do not match exactly; please debug!")

            if show_debug and not shown_once: print("-----------------------------------------------------------------------------------------------------------\n")

            ## Exceedence (marine heatwave) mask labeling for a series of connected events

            # To be able to use the scipy.ndimage.label function properly to detect events that last over year-end boundaries, 
            # it is crucial to create a continuous coordinate first (done in previous section) and run it through the custom function.

            # Exceedence bool mask creation (for marine heatwaves)
            exceed = (obs_time_aligned > thresh_time_aligned) # initial check if the observed temps are greater than their 90th percentiles
            exceed = exceed.fillna(False) # Replace NaNs with False   

            # Applying mhw event series labeling over the lat-lon grid
            events_gaps_filled = xr.apply_ufunc(
                event_gap_filling,
                exceed,
                input_core_dims=[['time_continuous']],
                output_core_dims=[['time_continuous']],
                vectorize=True,
                dask='parallelized',
                output_dtypes=[int],
                kwargs={'maxGap': maxGapDaysInMhw, 'minDuration':minDaysMhwDuration}, 
                dask_gufunc_kwargs={"output_sizes": {"time_continuous": exceed.sizes["time_continuous"]}},
            )

            events_gaps_filled = events_gaps_filled.transpose('time_continuous','latitude','longitude', 'depth')
            events_gaps_filled = events_gaps_filled.swap_dims({'time_continuous':'normalized_doy'})

            if show_debug and not shown_once: print("Event-labelled, gap-filled series: ", events_gaps_filled, '\n')

            # We run the MHW test with padded data, to ensure we correctly identify MHWs within the full current year period
            if mhw_test:
                run_mhw_test(exceed, events_gaps_filled, lat_val=mhw_test_lat, lon_val=mhw_test_lon, depth_val=mhw_test_depth)

            # Now, we remove the padding
            prev_yr_slice = smoothWidth if (prev_year_data is not None) else 0
            next_yr_slice = -smoothWidth if (next_year_data is not None) else None

            events_gaps_filled_unpadded = events_gaps_filled.isel(normalized_doy=slice(prev_yr_slice, next_yr_slice))
            if show_debug and not shown_once: print("Marine Heatwave Events Bool Dataset, Unpadded: ", events_gaps_filled_unpadded, '\n')

            mhw_bool_final = events_gaps_filled_unpadded.drop_vars('time_continuous').chunk({'normalized_doy':-1})
            if show_debug and not shown_once: 
                print("Final Marine Heatwave Events Bool Dataset: ", mhw_bool_final, '\n')
                print("-----------------------------------------------------------------------------------------------------------\n")

            # Rechunk the observed dataset
            obs_aligned = obs_aligned.swap_dims({'time':'normalized_doy'}).drop_vars('time').chunk(best_chunks)
            if show_debug and not shown_once: 
                print("Aligned observed data:\n", obs_aligned, '\n')
                print("Aligned thresh data:\n", thresh_aligned, '\n')
                print("Aligned clim data:\n", clim_aligned, '\n\n')

            ## SEVERITY DENOM (PC90 - CLIM)
            sev_denom = thresh_aligned - clim_aligned
            depth_wtd_sev_denom = sev_denom.weighted(layer_thickness).mean(dim='depth')
            
            mhw_sev_denom = xr.where(mhw_bool_final, thresh_aligned - clim_aligned, np.nan).chunk({'normalized_doy': -1})
            depth_wtd_mhw_sev_denom = mhw_sev_denom.weighted(layer_thickness).mean(dim='depth')
            
            ## SEVERITY NUM (OBS - CLIM)  
            sev_num = obs_aligned - clim_aligned
            depth_wtd_sev_num = sev_num.weighted(layer_thickness).mean(dim='depth')
            
            mhw_sev_num = xr.where(mhw_bool_final, obs_aligned - clim_aligned, np.nan).chunk({'normalized_doy': -1})
            depth_wtd_mhw_sev_num = mhw_sev_num.weighted(layer_thickness).mean(dim='depth')

            ## SEVERITY (NUM/DENOM)
            severity = depth_wtd_sev_num / depth_wtd_sev_denom
            severity = severity.chunk(best_chunks_no_depth)
            
            mhw_severity = depth_wtd_mhw_sev_num / depth_wtd_mhw_sev_denom
            mhw_severity = mhw_severity.chunk(best_chunks_no_depth)
            
            if show_debug: 
                print("Depth-Weighted Severity [(OBS - CLIM) / (PC90 - CLIM)]:\n", severity, '\n')
                print("Depth-Weighted Severity (only for marine heatwaves):\n", mhw_severity, '\n')
                print("-----------------------------------------------------------------------------------------------------------\n")

            # Reassigning time
            if len(obs_aligned_time) == severity.sizes['normalized_doy']:
                # We rename the mhw labeled events dataset for merging
                mhw_sev_final = mhw_severity.rename(f'severity_mhw_{final_max_depth_val}')
                sev_final = severity.rename(f'severity_{final_max_depth_val}')

                # Merge the datasets
                full_sev = xr.merge([sev_final, mhw_sev_final])

                # If there is perfect alignment with our severity dataset, we assign the original time coordinate back to it!
                severity_final = full_sev.assign_coords(time=('normalized_doy', obs_aligned_time.values))
                severity_final = severity_final.swap_dims({'normalized_doy': 'time'}).drop_vars('normalized_doy').chunk({'time':-1})
            else:
                # If there's a dimension mismatch, we raise an error!
                raise ValueError(f"Time/normalized_doy dimension mismatch detected for computation of severity for the {year} year!")
            
            if show_debug and not shown_once:
                shown_once = True
                
                
            ## SAVING
            print(f"Starting Save of {baseline_name} {year} Severity Dataset (Averaged Up to a Depth of {final_max_depth_val} m)...\n")

            # for filepath setup
            def safe_float_str(x):
                """Render ``x`` as text with every '.' swapped for '-'.

                Used when building the output filename so that non-integer depth
                values (e.g. 143.5 -> '143-5') do not introduce extra dots.
                """
                text = str(x)
                dotless = text.replace('.', '-')
                return dotless
            
            # Setting up the destination filepath
            id_path = f"{folder_name}_{sub_folder_name}"
            filename = f"{id_path}_severity_300m_PC90th_{year}_depth_0_to_{safe_float_str(final_max_depth_val)}_{baseline_name}.zarr"
            sev_filepath = f'YOUR_DIRECTORY/Severity_300m/{folder_name}/{id_path}/{filename}'
            print("File Path: ", sev_filepath, '\n')

            # Last-minute chunking to not hit the maximum buffer size for chunks
            if len(severity_final.time.values) == 365:
                severity_final = severity_final.chunk({'time':73})
            elif len(severity_final.time.values) == 366:
                severity_final = severity_final.chunk({'time':61})
            print(f"{year} Severity:\n", severity_final, '\n')

            if show_debug:
                raise ValueError("You have reached the end of the debug. To begin saving the data, set show_debug to False!")
            else:
                with ProgressBar():
                    severity_final.to_zarr(sev_filepath, mode='w')
                
                # Freeing up memory
                severity_final.close()
                del severity_final
                gc.collect()
                
                print("Done saving! Moving along!", '\n')

            print("-----------------------------------------------------------------------------------------------------------\n")
        
        print(f"Finished saving the column-averaged severity for depths up to {final_max_depth_val} m for all years! Moving on to the next depth!\n")
        print("-----------------------------------------------------------------------------------------------------------\n")
    
    print("--------------------o-------------------------------o------------------------------o----------------------------", '\n')
    print("Finished saving data for all years!")
    stop_monitoring = True
       
        
# Dask chunk layouts per region for the depth-resolved inputs.
# If you get warnings about large chunks, use smaller chunks in the previous step where the data were combined,
# or adjust these chunk sizes accordingly.
best_chunk_configs = {
    "Atlantic": {'depth': -1, 'normalized_doy': -1, 'latitude': 109, 'longitude': 242},
}

# Chunk layouts for the depth-averaged (no 'depth' axis) outputs: the same layout as above with the
# 'depth' entry removed, derived here so the two configs can never drift apart.
post_thickness_configs = {
    region_name: {dim: size for dim, size in chunks.items() if dim != 'depth'}
    for region_name, chunks in best_chunk_configs.items()
}

# Pick your MINIMUM target depth(s): the column-averaged severity you obtain extends to a depth lying between
# this depth level and the next one. For example, with 130 m as the lower bound, the next level was ~156 m and
# the final averaging depth is ~143 m. Uncomment other depths to process them as well.
target_depths_list = [
    # 266,
    # 47,
    # 92,
    130,
]

# Single place to change the region; used both for the chunk-config lookups and the calculate_severity call.
REGION = "Atlantic"

optimal_chunks = best_chunk_configs[REGION]
sev_optimal_chunks = post_thickness_configs[REGION]

try:
    calculate_severity(globals()[f'{folder_filename}_obs'],
                       # errors='ignore' so threshold datasets saved without a 'quantile' coordinate do not crash here
                       globals()[f'{folder_filename}_thresh'].drop_vars('quantile', errors='ignore'),
                       globals()[f'{folder_filename}_clim'],
                       "Baseline9322", REGION, "Full",
                       optimal_chunks, sev_optimal_chunks,
                       depths_list=target_depths_list, depth_thicknesses=all_depth_levels,
                       minutes_per_mem_update=20, show_debug=False,
                       custom_years=True, start_yr=2000, end_yr=2002,
                       mhw_test=False)
    # To test the heatwave detection at a single grid point, replace `mhw_test=False` above with, e.g.:
    # mhw_test=True, mhw_test_lat=25, mhw_test_lon=-40, mhw_test_depth=50)
finally:
    # Always stop the memory monitor, even if calculate_severity raises
    # (it deliberately raises a ValueError at the end of a show_debug=True run).
    stop_monitoring = True

In [None]:
# Manual stop for the memory-printing monitor started earlier: setting this flag is presumably what the
# monitoring loop polls to exit (verify against where that loop is defined). Run this cell if the severity
# cell above was interrupted or raised before the flag was reset.
stop_monitoring = True