In [None]:
import xarray as xr
import numpy as np
import pandas as pd

import glob
import psutil
import threading
import time
import os
import gc

from scipy import ndimage

import dask
from dask.diagnostics import ProgressBar

import copernicusmarine
from copernicusmarine import get

from datetime import datetime
from datetime import date

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.animation as animation
import cartopy.crs as ccrs
import cartopy.feature as cfeature

In [None]:
## Important script-wide constants
start_val_arg = 1 # first day-of-year (doy) value of your percentile and mean datasets (both must share it)
end_val_arg = 366 # last day-of-year (doy) value of your percentile and mean datasets (both must share it)
baseline_name_arg = "Baseline9322" # your chosen baseline identifier for the 30-year baseline
current_percentile = 90 # your chosen percentile, given as a percent (e.g. 90), not as a fraction (e.g. 0.9)
custom_id_choice = "fgd" # your custom identifier used in both the mean and the percentile dataset file names
my_root_directory = "" # Your root directory, which data is read from and saved to
# See sst_climatology_and_percentile_calculator.ipynb for the values you used (or consult your stored datasets)
dataset_id = "" # Name of the dataset folder holding your raw data, means, and percentiles calculated with the climatology script


# Data directories that can be further modified by you may be identified by searching for 
# "NOTE: POTENTIAL DIRECTORY TWEAKING HERE" in this script.

In [None]:
## Loading SST data ---------------------------------------------------------------------------------
def normalize_dayofyear(time_coord):
    """
    Map a datetime coordinate onto a leap-year-style day-of-year axis.

    Every date keeps its calendar day-of-year in leap years. In non-leap years,
    day 60 (March 1) and every later day are shifted forward by one, so that
    March 1 is always day 61 and December 31 is always day 366. Day 60 is
    therefore only ever used by February 29.

    Parameters
    ----------
    time_coord : object exposing ``.dt.dayofyear`` and ``.dt.is_leap_year``
        The time coordinate to normalize.

    Returns
    -------
    The result of ``xr.where`` holding the normalized day-of-year values.
    """
    raw_doy = time_coord.dt.dayofyear
    in_common_year = ~time_coord.dt.is_leap_year

    # In a non-leap year, doy 60 is already March 1, so it and everything after it moves up by one
    needs_shift = in_common_year & (raw_doy >= 60)
    return xr.where(needs_shift, raw_doy + 1, raw_doy)

# Your filepath here; this is my setup.
raw_data_directory = f'{my_root_directory}/{dataset_id}/Data/sst.day.mean.*.nc'

# Sometimes this will result in a crash the first time. Wait a bit and run the cell again...
ds = xr.open_mfdataset(
    raw_data_directory,         # Glob pattern
    parallel=True,              # Enable parallel file opening 
    chunks='auto',              # Let Dask choose optimal chunking
    combine='by_coords',        # Merge based on coordinate values
    engine='netcdf4'            # Specify the engine (it may crash otherwise; if it still does, restart the kernel)
)

optimal_chunking = {'lat': 210, 'lon': 160} # for the raw data
optimal_chunks = {'normalized_doy':-1, 'lat': 210, 'lon': 160} # for the to-be-created severity data

# Spatial subset (15S-90N, all longitudes) of the SST variable, rechunked for the raw data
full_ds = ds.sel(lat=slice(-15, 90), lon=slice(0, 360)).sst
full_ds = full_ds.chunk(optimal_chunking)

# Attach the leap-year-normalized day of year as an extra coordinate along time
Full_obs = full_ds.assign_coords(
        normalized_doy=('time', normalize_dayofyear(full_ds.time).data))
print("Observed data:\n", Full_obs, '\n')

## Loading constants
folder_name_arg = "Full"

# The identifier embedded in the mean/percentile file names comes from the script-wide
# constant, so that loading here and saving in calculate_severity always agree.
custom_name_arg = custom_id_choice
# NOTE: `id` shadows the Python builtin; the name is kept because it is defined at notebook level.
if custom_name_arg != "":
    id = f"{custom_name_arg}_"
else:
    id =""
    
## Loading climatological means data 
file_name = f"{folder_name_arg}_{id}sst_clim_subset_{start_val_arg}_to_{end_val_arg}_{baseline_name_arg}.zarr"
filepath = f'{my_root_directory}/{dataset_id}/Clim/{baseline_name_arg}/{file_name}' # NOTE: POTENTIAL DIRECTORY TWEAKING HERE
Full_means = xr.open_zarr(filepath).sst
print("Clim. Means:\n", Full_means, '\n')

## Loading percentile threshold data
file_name = f"{folder_name_arg}_{id}sst_thresh_subset_{start_val_arg}_to_{end_val_arg}_{baseline_name_arg}.zarr"
filepath = f'{my_root_directory}/{dataset_id}/Thresh{current_percentile}th/{baseline_name_arg}/{file_name}' # NOTE: POTENTIAL DIRECTORY TWEAKING HERE
Full_thresh = xr.open_zarr(filepath).sst
print("Percentile Thresholds:\n", Full_thresh, '\n')

In [None]:
## Function to calculate column-averaged severity
def get_max_and_min_years(ds):
    """
    Return the years of the first and last entries of ``ds.time``.

    Despite the function name, the tuple is ordered ``(min_year, max_year)``:
    the year of ``ds.time[0]`` first, then the year of ``ds.time[-1]``.
    The time axis is not sorted here; the first/last positions are used as is.
    """
    first_stamp = ds.time[0]
    last_stamp = ds.time[-1]
    return first_stamp.dt.year.item(), last_stamp.dt.year.item()


def check_valid_type(data, data_type, data_name, correction_message):
    """
    Validate that ``data`` is an instance of ``data_type``.

    Parameters
    ----------
    data : the value to check.
    data_type : the type ``data`` must be an instance of.
    data_name : human-readable name of the value, used in the error message.
    correction_message : alternative the user can take, appended to the error message.

    Raises
    ------
    ValueError
        If ``data`` is not an instance of ``data_type``.
    """
    if isinstance(data, data_type):
        return

    expected_type_name = data_type.__name__
    raise ValueError(f"Invalid {data_name} provided. Please provide a proper {expected_type_name} {data_name}, or {correction_message}.")


def event_gap_filling(events_series: np.ndarray, minDuration: int = 5, maxGap: int = 2) -> np.ndarray:
    '''
    Label marine heatwave (MHW) events in a 1D exceedance series, following the
    Hobday et al. (2016) definition.

    Rules applied:
    1. A run of consecutive True days only counts if it lasts minDuration days or more.
    2. Two qualifying runs separated by maxGap cool days or fewer are merged.
    3. A merged event covers everything from the start of its first run to the end
       of its last run (the cool gap days included).
    4. Runs shorter than minDuration never extend or bridge an event.

    Examples (minDuration=5, maxGap=2):
    - [5hot, 2cool, 6hot] → one 13-day event
    - [5hot, 1cool, 2hot] → one 5-day event (the first 5 days)
    - [2hot, 1cool, 5hot] → one 5-day event (the last 5 days)
    - [5hot, 4cool, 6hot] → two separate events

    Parameters
    ----------
    events_series : 1D np.ndarray; non-boolean input is cast to bool (a notice is printed).
    minDuration : minimum length, in days, of a qualifying run.
    maxGap : longest gap, in days, that is bridged between two qualifying runs.

    Returns
    -------
    np.ndarray of int with the same shape: 0 outside events, and 1, 2, 3, ...
    (in chronological order) on the days belonging to each event.
    '''
    assert events_series.ndim == 1, f"Expected 1D series, got {events_series.ndim}D"

    if events_series.dtype != bool:
        print("Dataset conversion occurred!", '\n')
        events_series = events_series.astype(bool)

    labels = np.zeros_like(events_series, dtype=int)
    if not events_series.any():
        return labels

    # Pad with one False on each side so that every run has both a rising and a falling edge
    flags = np.concatenate(([False], events_series, [False])).astype(np.int8)
    edges = np.diff(flags)
    run_starts = np.flatnonzero(edges == 1)       # index of the first True of each run
    run_ends = np.flatnonzero(edges == -1) - 1    # index of the last True of each run (inclusive)

    # Keep only the runs that are long enough to qualify on their own
    qualifies = (run_ends - run_starts + 1) >= minDuration
    run_starts = run_starts[qualifies]
    run_ends = run_ends[qualifies]
    if run_starts.size == 0:
        return labels

    # Sweep through the qualifying runs, extending the open event while the gap is small enough
    event_number = 1
    span_start, span_end = run_starts[0], run_ends[0]
    for next_start, next_end in zip(run_starts[1:], run_ends[1:]):
        cool_days_between = next_start - span_end - 1
        if cool_days_between <= maxGap:
            span_end = next_end
        else:
            labels[span_start:span_end + 1] = event_number
            event_number += 1
            span_start, span_end = next_start, next_end

    # Close the event that is still open once all runs have been consumed
    labels[span_start:span_end + 1] = event_number
    return labels
        
    
# Create a continous time coordinate for a passed dataset
def create_continuous_time_coordinate(ds):
    """
    Re-index ``ds`` along a gap-free integer coordinate named ``time_continuous``.

    A 1..N counter (N = length of ``ds.normalized_doy``) is attached along the
    ``normalized_doy`` dimension, the dimension is swapped to ``time_continuous``,
    and the result is rechunked so that the new dimension sits in a single chunk.
    """
    step_index = np.arange(1, len(ds.normalized_doy) + 1)

    with_counter = ds.assign_coords(time_continuous=('normalized_doy', step_index))
    reindexed = with_counter.swap_dims({'normalized_doy':'time_continuous'})
    return reindexed.chunk({'time_continuous':-1})


def run_mhw_test(initial_mask_ds, final_mask_ds, lat_val=None, lon_val=None):
    """
    Print, for one grid point, the raw exceedance series next to the
    event-labelled (gap-filled) series so the two can be compared by eye.

    Parameters
    ----------
    initial_mask_ds : dataset with the raw exceedance values; selected with
        ``.sel(lat=..., lon=..., method='nearest')``.
    final_mask_ds : dataset with the event-labelled values; selected the same way.
    lat_val, lon_val : requested latitude and longitude (nearest grid point is used).

    Raises
    ------
    ValueError
        If ``lat_val`` or ``lon_val`` is None.
    """
    # Both coordinates are mandatory
    if lat_val is None or lon_val is None:
        raise ValueError("Missing a key argument in the run_mhw_test function. Please ensure a valid integer value is provided for each value argument!")

    print("--------------------------------------------------   TEST START   ---------------------------------------------------\n")
    print(f"RUNNING TEST FOR NEAREST LATITUDE: {lat_val}, LONGITUDE: {lon_val}.\n")

    # Nearest grid point in the raw exceedance mask, and the coordinates actually matched
    point_mask = initial_mask_ds.sel(lat=lat_val, lon=lon_val, method='nearest')
    matched_lat = point_mask.lat.values.item()
    matched_lon = point_mask.lon.values.item()
    print(f"ACTUAL VALUES DETECTED:\nLATITUDE: {matched_lat}, LONGITUDE: {matched_lon}\n")

    exceed_series = point_mask.values
    print("Exceed values:\n", exceed_series, '\n\n')

    labelled_series = final_mask_ds.sel(lat=lat_val, lon=lon_val, method='nearest').values
    print("Event gap-filled values:\n", labelled_series, '\n\n')

    print("----------o-----------0----------o-------------\n")

    # Side-by-side listing in windows of ten values, limited to the shorter series
    window = 10
    n_shared = min(len(exceed_series), len(labelled_series))

    for window_start in range(0, n_shared, window):
        window_stop = min(window_start + window, n_shared)
        print(f"Values {window_start} to {window_stop - 1}:")
        print("Exceed values:", exceed_series[window_start:window_stop])
        print("Event gap-filled values:", labelled_series[window_start:window_stop])
        print("-" * 60)
        print(" ")
    print("--------------------------------------------------   END OF TEST   ---------------------------------------------------\n")


stop_monitoring = True # We begin by NOT showing any memory usage
def monitor_memory(interval_minutes=5, log_file=None):
    """
    Periodically report this process's memory usage until ``stop_monitoring`` is True.

    Intended to run in a background thread. On every cycle the resident memory
    of the current process (in GB) and the system-wide memory percentage are
    printed, and optionally appended to ``log_file`` with a timestamp.

    Parameters
    ----------
    interval_minutes : time between two reports, in minutes.
    log_file : optional path of a text file that each report is appended to.

    Notes
    -----
    The wait between reports is split into slices of at most one second, each
    followed by a check of the module-level ``stop_monitoring`` flag. The thread
    therefore ends shortly after the flag is raised instead of sleeping through
    the rest of the interval, which prevents a stale monitor from surviving into
    (and printing alongside) the next run.
    """
    interval = interval_minutes * 60  
    
    while not stop_monitoring:
        mem = psutil.Process(os.getpid()).memory_info().rss / (1024**3)  # in GB
        print(f" | Memory usage: {mem:.2f} GB | Memory: {psutil.virtual_memory().percent}% used | ")
        
        if log_file:
            with open(log_file, 'a') as f:
                f.write(f"{time.strftime('%Y-%m-%d %H:%M:%S')}: {mem:.2f} GB\n")

        # Wait for the next report, waking up regularly to notice a stop request
        deadline = time.monotonic() + interval
        while not stop_monitoring:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            time.sleep(min(1.0, remaining))
    
    
# -----------------------------------------------------------------------------------------------------------------------------
def calculate_severity(obs_data, thresh_data, clim_data, 
                       baseline_name, folder_name, custom_id,
                       best_chunks,
                       smoothWidth=31, percentile=90,
                       minDaysMhwDuration=5, maxGapDaysInMhw=2, 
                       minutes_per_mem_update=5,
                       show_debug=True,
                       custom_years=False, start_yr=None, end_yr=None,
                       mhw_test=False, mhw_test_lat=None, mhw_test_lon=None):
    """
    Compute daily SST severity year by year and save one zarr store per year.

    For each year the observed SST is aligned on ``normalized_doy`` with the
    (optionally smoothed) climatological means and percentile thresholds.
    Marine heatwave (MHW) events are labelled with ``event_gap_filling`` on the
    year padded with ``doys_to_gather`` days borrowed from the neighbouring
    years, so that events crossing a year boundary are still recognised; the
    padding is removed again afterwards. Two variables are written per year:

    - ``severity``: (obs - clim) / (thresh - clim) for every day;
    - ``severity_mhw``: the same ratio, NaN outside labelled MHW events.

    Parameters
    ----------
    obs_data : observed SST with a ``time`` dimension and a ``normalized_doy`` coordinate.
    thresh_data : percentile thresholds indexed by ``normalized_doy``.
    clim_data : climatological means indexed by ``normalized_doy``.
    baseline_name, folder_name, custom_id : used to build the output file name/path
        (``custom_id`` may be "" to omit the identifier).
    best_chunks : chunking applied to the doy-indexed intermediate datasets.
    smoothWidth : width (in doys) of the centred rolling mean applied to the means and
        thresholds; 0 disables smoothing.
    percentile : percentile value, used only in the output file name/path.
    minDaysMhwDuration, maxGapDaysInMhw : forwarded to ``event_gap_filling`` as
        ``minDuration`` and ``maxGap``.
    minutes_per_mem_update : reporting interval of the memory monitor thread.
    show_debug : when True, intermediate datasets are printed and a ValueError is raised
        at the end of the first year, before anything is saved; when False, a memory
        monitor thread is started and every year is saved.
    custom_years, start_yr, end_yr : when ``custom_years`` is True, process
        ``start_yr``..``end_yr`` (inclusive, both int); otherwise all observed years.
    mhw_test, mhw_test_lat, mhw_test_lon : when ``mhw_test`` is True, run ``run_mhw_test``
        at the given (int) coordinates for every processed year.

    Raises
    ------
    ValueError
        For invalid year/coordinate arguments, when the padded observed and threshold
        doy coordinates do not match, when the time axis cannot be re-attached, and
        (deliberately) at the end of a debug run.

    Notes
    -----
    Reads the module-level ``my_root_directory`` and ``dataset_id`` to build the output
    path and sets the module-level ``stop_monitoring`` flag. The flag is always set back
    to True when the yearly loop ends, including when it ends with an exception.
    """
    global stop_monitoring

    # Quick check for inputs regarding a marine heatwave test at a specific coordinate
    if mhw_test:
        check_valid_type(mhw_test_lat, int, "latitude value", "set mhw_test to False")
        check_valid_type(mhw_test_lon, int, "longitude value", "set mhw_test to False")
    
    # Gather the min and max years in the observed data
    min_obs_yr, max_obs_yr = get_max_and_min_years(obs_data)

    # Setting up custom years, or going with the default (min to max years in the observed data)
    if custom_years:
        check_valid_type(start_yr, int, "starting year", "set custom_years to False")
        check_valid_type(end_yr, int, "ending year", "set custom_years to False")
        years = range(start_yr, end_yr + 1)
    else:
        years = range(min_obs_yr, max_obs_yr + 1)

    # Controls whether a second smoothing (of the climatological means/percentiles) is applied
    apply_smoothing = smoothWidth != 0
    shown_once = False # for debugging purposes

    # Identifier prefix embedded in the saved file name ("" when no custom id is wanted)
    id_prefix = f"{custom_id}_" if custom_id != "" else ""
    
    # Beginning of code for severity calculation:
    print("--------------------o-------------------------------o------------------------------o----------------------------")
    print(f'\nBaseline used: {baseline_name}\n') # We only have one baseline
    print(f"The smoothWidth argument was set to: {smoothWidth}.")
    
    if apply_smoothing:
        print(f"This means a rolling window of {smoothWidth} doys centered on each individual doy is used.")
        print(f"For each doy, the rolling window contains the previous {(smoothWidth-1)/2} doys, the center doy, and the next {(smoothWidth-1)/2} doys.\n")
    else:
        print("No smoothing was applied; using the threshold and mean datasets as is!\n")

    # Showing the original datasets if show_debug is enabled
    if show_debug: 
        print('Original raw, "observed" data:\n', obs_data, '\n')
        print("Original percentile threshold data:\n", thresh_data, '\n')
        print("Original climatological means data:\n", clim_data, '\n')
        print("-----------------------------------------------------------------------------------------------------------\n")

    if apply_smoothing:
        ## PADDING AND SMOOTHING
        # Padding with smoothWidth (which is more than enough for the desired smoothing with smoothWidth).
        # With mode='wrap' the end of the doy axis is placed in front and the start is placed behind.
        padded_clim = clim_data.pad(normalized_doy=smoothWidth, mode='wrap')
        if show_debug: print("Padded Means:\n", padded_clim, '\n')
    
        padded_thresh = thresh_data.pad(normalized_doy=smoothWidth, mode='wrap')
        if show_debug: print("Padded Threshold:\n", padded_thresh, '\n\n')
    
        # Smoothing
        clim_smoothed = padded_clim.rolling(normalized_doy=smoothWidth, center=True, min_periods=smoothWidth).mean() 
        if show_debug: print(f"Padded Means Smoothed by {smoothWidth} doys:\n", clim_smoothed, '\n')
    
        clim_processed = clim_smoothed.isel(normalized_doy=slice(smoothWidth, -smoothWidth))
        if show_debug: print("Smoothed Means (No Pad):\n", clim_processed, '\n\n')
    
        thresh_smoothed = padded_thresh.rolling(normalized_doy=smoothWidth, center=True, min_periods=smoothWidth).mean()
        if show_debug: print(f"Padded Threshold Smoothed by {smoothWidth} doys: ", thresh_smoothed, '\n')
    
        thresh_processed = thresh_smoothed.isel(normalized_doy=slice(smoothWidth, -smoothWidth))
        if show_debug: print("Threshold Smoothed (No Pad): ", thresh_processed, '\n')

        # We also save early and late smoothed doy data for late/early-year mhw detection later.
        # The leading pad holds the end-of-year doys (used for the previous year) and the
        # trailing pad holds the start-of-year doys (used for the next year).
        # MAY NEED TO SET THESE ELSEWHERE IF USING A LARGER smoothWidth (if the smoothWidth extends past 31 to include the interpolated Feb 29)
        doys_to_gather = smoothWidth # can be adjusted to your needs
        thresh_prev_yr_days = thresh_smoothed.isel(normalized_doy=slice(0, doys_to_gather)).copy()
        thresh_next_yr_days = thresh_smoothed.isel(normalized_doy=slice(-doys_to_gather, None)).copy()
        
    else:
        # Without smoothing, the raw datasets are used directly
        thresh_processed = thresh_data
        clim_processed = clim_data

        # We still save late and early raw threshold doy data for late/early-year mhw detection later.
        # The days borrowed from the PREVIOUS year are its last doys, so they are compared against
        # the END of the doy axis; the days borrowed from the NEXT year are its first doys, so they
        # are compared against the START of the axis. Any other pairing makes the padded observed
        # and threshold doy coordinates disagree in the equality check inside the yearly loop.
        doys_to_gather = 15 # additional days to append to the dataset for year-round mhw detection (should be greater than 6)
        thresh_prev_yr_days = thresh_data.isel(normalized_doy=slice(-doys_to_gather, None)).copy()
        thresh_next_yr_days = thresh_data.isel(normalized_doy=slice(0, doys_to_gather)).copy()

    if show_debug:
        print(f'\n"Next Year" Threshold Dataset Doys:\n{thresh_next_yr_days}\n')
        print(f'"Previous Year" Threshold Dataset Doys:\n{thresh_prev_yr_days}\n')

    
    # Quick monitoring initiation
    if show_debug:
        stop_monitoring = True # We don't want to start showing memory use.
        print("-----------------------------------------------------------------------------------------------------------\n")
    else:
        # We start monitoring here so that it only runs once
        stop_monitoring = False # We do want to start showing memory use.
        monitor_thread = threading.Thread(target=monitor_memory, kwargs={'interval_minutes': minutes_per_mem_update})
        monitor_thread.daemon = True
        monitor_thread.start()

    # The try/finally guarantees the memory monitor is told to stop even if a year fails
    try:
        ## CALCULATING ANOMALIES BY YEAR (observed - climatological means) AND DETECTING MHWS
        for year in years:
            # Time subsetting
            obs_year_data = obs_data.sel(time=f'{year}')
            obs_year_data_norm = obs_year_data.swap_dims({'time':'normalized_doy'})
            if show_debug and not shown_once: print(f"Observed data for {year}:\n", obs_year_data_norm, '\n')

            # Aligning the datasets (returns doys common to both)
            obs_aligned, clim_aligned = xr.align(obs_year_data_norm, clim_processed, join="inner")
            clim_aligned = clim_aligned.chunk(best_chunks)

            obs_aligned, thresh_aligned = xr.align(obs_year_data_norm, thresh_processed, join="inner")
            thresh_aligned = thresh_aligned.chunk(best_chunks)

            if show_debug and not shown_once: 
                print("Thresh aligned: ", thresh_aligned, '\n')
                print("Clim aligned: ", clim_aligned, '\n')

            # We also save the time variable separately for later, then drop it from the observed
            obs_aligned_time = obs_aligned.swap_dims({'normalized_doy':'time'}).drop_vars('normalized_doy').time
            if show_debug and not shown_once: print("-----------------------------------------------------------------------------------------------------------\n")

            ### DETECTING MHWS ACROSS THE YEARS
            ## Grabbing previous and next year data
            obs_aligned = obs_aligned.swap_dims({'normalized_doy':'time'})
            if show_debug and not shown_once: print(f"Full Year Data from the Current Year ({year}):\n", obs_aligned, '\n')

            # Important objects for padding
            prev_year_data = None
            next_year_data = None

            # We grab the previous year's data unless the current year is the first observed year
            if year > min_obs_yr:
                prev_year = year - 1
                prev_year_full = obs_data.sel(time=f'{prev_year}')

                # If it exists, we grab the previous year's final doys_to_gather days
                if len(prev_year_full) > 0:
                    prev_year_data = prev_year_full.isel(time=slice(-doys_to_gather, None))
                    if show_debug and not shown_once: print(f"End of the Year Data from the Previous Year ({prev_year}):\n", prev_year_data, '\n')

            # We grab the next year's data unless the current year is the last observed year
            if year < max_obs_yr:
                next_year = year + 1
                next_year_full = obs_data.sel(time=f'{next_year}')

                # If it exists, we grab the next year's first doys_to_gather days
                if len(next_year_full) > 0:
                    next_year_data = next_year_full.isel(time=slice(0, doys_to_gather))
                    if show_debug and not shown_once: print(f"Beginning of the Year Data from the Next Year ({next_year}):\n", next_year_data, '\n')

            ## Merging the previous and next year datasets, if they are available
            data_pieces = [] 

            # We first append the previous year data to our list, if it exists
            if prev_year_data is not None: data_pieces.append(prev_year_data)

            # Next, we append the current year data to our list
            data_pieces.append(obs_aligned)

            # Lastly, we append the next year data to our list, if it exists
            if next_year_data is not None: data_pieces.append(next_year_data)

            if show_debug and not shown_once: print("Data pieces (prev + full current + next):\n", data_pieces, '\n\n')

            # Merging the datasets
            obs_year_data_extended = xr.concat(data_pieces, dim='time')
            obs_year_data_extended = obs_year_data_extended.sortby('time')
            if show_debug and not shown_once: print("Initial merged observed data:\n", obs_year_data_extended, '\n')

            # We switch back to our desired normalized doy time dimension
            obs_year_data_extended = obs_year_data_extended.swap_dims({'time':'normalized_doy'}).drop_vars('time')
            obs_year_data_extended = obs_year_data_extended.chunk(best_chunks)
            if show_debug and not shown_once: 
                print("Final merged observed data: ", obs_year_data_extended, '\n')
                print("-----------------------------------------------------------------------------------------------------------\n")

            ## Threshold padding
            data_pieces_thresh = []

            # We append previous year data if it exists first
            if prev_year_data is not None:
                data_pieces_thresh.append(thresh_prev_yr_days)
                if show_debug and not shown_once: print('"Previous" Year Threshold Data:\n', thresh_prev_yr_days, '\n')  # the last doys_to_gather doys, up to 366

            # We then append current year data
            data_pieces_thresh.append(thresh_aligned)

            # We lastly append next year data if it exists
            if next_year_data is not None:
                data_pieces_thresh.append(thresh_next_yr_days)
                if show_debug and not shown_once: print('"Next" Year Threshold Data:\n', thresh_next_yr_days, '\n') # doy 1 up to the end of doys_to_gather doys

            if show_debug and not shown_once: print("Data pieces thresh (prev + full current + next):\n", data_pieces_thresh, '\n\n')

            # Since we cannot sort by time, order matters most here!
            thresh_data_extended = xr.concat(data_pieces_thresh, dim='normalized_doy')
            thresh_data_extended = thresh_data_extended.chunk(best_chunks)
            if show_debug and not shown_once: 
                print("Final Merged Threshold Data:\n", thresh_data_extended, '\n')
                print("-----------------------------------------------------------------------------------------------------------\n")


            ## Creating continuous coordinates (but only if the obs and thresh datasets have matching days of the year)
            if obs_year_data_extended.normalized_doy.equals(thresh_data_extended.normalized_doy):
                if show_debug and not shown_once: 
                    print("Aligned observed and padded thresh datasets' normalized_doys match!\n")
                    print("Creating a new, continous coordinate for each for marine heatwave detection...\n")

                obs_time_aligned = create_continuous_time_coordinate(obs_year_data_extended)
                thresh_time_aligned = create_continuous_time_coordinate(thresh_data_extended)

                if show_debug and not shown_once:
                    print("Continuous Observed: ", '\n', obs_time_aligned, '\n\n',
                          "Continuous Thresh: ", '\n', thresh_time_aligned, '\n')
            else:
                print("ERROR DETECTED! PRINTING RELEVANT OUTPUT:\n")
                print(obs_year_data_extended.normalized_doy, '\n\n', thresh_data_extended.normalized_doy, '\n')
                print(obs_year_data_extended.normalized_doy.values, '\n',
                     thresh_data_extended.normalized_doy.values)

                raise ValueError("Unexpected Error Detected: Coordinate arrays of observed and threshold datasets do not match exactly; please debug!")

            if show_debug and not shown_once: print("-----------------------------------------------------------------------------------------------------------\n")

            ## Exceedence (marine heatwave) mask labeling for a series of connected events

            # The padded year repeats doy values (end of previous year, start of next year), so the
            # event labelling runs along the continuous coordinate created in the previous section.

            # Exceedence bool mask creation (for marine heatwaves)
            exceed = (obs_time_aligned > thresh_time_aligned) # initial check if the observed temps are greater than their percentile thresholds
            exceed = exceed.fillna(False) # Replace NaNs with False   

            # Applying mhw event series labeling over the lat-lon grid
            events_gaps_filled = xr.apply_ufunc(
                event_gap_filling,
                exceed,
                input_core_dims=[['time_continuous']],
                output_core_dims=[['time_continuous']],
                vectorize=True,
                dask='parallelized',
                output_dtypes=[int],
                kwargs={'maxGap': maxGapDaysInMhw, 'minDuration':minDaysMhwDuration}, 
                dask_gufunc_kwargs={"output_sizes": {"time_continuous": exceed.sizes["time_continuous"]}},
            )

            events_gaps_filled = events_gaps_filled.transpose('time_continuous','lat','lon')
            events_gaps_filled = events_gaps_filled.swap_dims({'time_continuous':'normalized_doy'})

            if show_debug and not shown_once: print("Event-labelled, gap-filled series: ", events_gaps_filled, '\n')

            # We run the MHW test with padded data, to ensure we correctly identify MHWs within the full current year period
            if mhw_test:
                run_mhw_test(exceed, events_gaps_filled, lat_val=mhw_test_lat, lon_val=mhw_test_lon)

            # Now, we remove the padding (only on the sides where padding was actually added)
            prev_yr_slice = doys_to_gather if (prev_year_data is not None) else 0
            next_yr_slice = -doys_to_gather if (next_year_data is not None) else None

            events_gaps_filled_unpadded = events_gaps_filled.isel(normalized_doy=slice(prev_yr_slice, next_yr_slice))
            if show_debug and not shown_once: print("Marine Heatwave Events Bool Dataset, Unpadded: ", events_gaps_filled_unpadded, '\n')

            # Integer event labels: 0 outside events, so the array doubles as a mask below
            mhw_bool_final = events_gaps_filled_unpadded.drop_vars('time_continuous').chunk({'normalized_doy':-1})
            if show_debug and not shown_once: 
                print("Final Marine Heatwave Events Bool Dataset: ", mhw_bool_final, '\n')
                print("-----------------------------------------------------------------------------------------------------------\n")

            # Rechunk the observed dataset
            obs_aligned = obs_aligned.swap_dims({'time':'normalized_doy'}).drop_vars('time').chunk(best_chunks)
            if show_debug and not shown_once: 
                print("Aligned observed data:\n", obs_aligned, '\n')
                print("Aligned thresh data:\n", thresh_aligned, '\n')
                print("Aligned clim data:\n", clim_aligned, '\n\n')

            ## SEVERITY DENOM (PC90 - CLIM)
            sev_denom = thresh_aligned - clim_aligned
            mhw_sev_denom = xr.where(mhw_bool_final, thresh_aligned - clim_aligned, np.nan).chunk({'normalized_doy': -1})
            
            ## SEVERITY NUM (OBS - CLIM)  
            sev_num = obs_aligned - clim_aligned
            mhw_sev_num = xr.where(mhw_bool_final, obs_aligned - clim_aligned, np.nan).chunk({'normalized_doy': -1})

            ## SEVERITY (NUM/DENOM)
            severity = sev_num / sev_denom
            severity = severity.chunk(best_chunks)
            
            mhw_severity = mhw_sev_num / mhw_sev_denom
            mhw_severity = mhw_severity.chunk(best_chunks)
            
            if show_debug: 
                print("Severity [(OBS - CLIM) / (PC90 - CLIM)]:\n", severity, '\n')
                print("Severity (only for marine heatwaves):\n", mhw_severity, '\n')
                print("-----------------------------------------------------------------------------------------------------------\n")

            # Reassigning time
            if len(obs_aligned_time) == severity.sizes['normalized_doy']:
                # We name both severity arrays so they become separate variables when merged
                mhw_sev_final = mhw_severity.rename('severity_mhw')
                sev_final = severity.rename('severity')

                # Merge the datasets
                full_sev = xr.merge([sev_final, mhw_sev_final])

                # If there is perfect alignment with our severity dataset, we assign the original time coordinate back to it!
                severity_final = full_sev.assign_coords(time=('normalized_doy', obs_aligned_time.values))
                severity_final = severity_final.swap_dims({'normalized_doy': 'time'}).drop_vars('normalized_doy').chunk({'time':-1})
            else:
                # If there's a dimension mismatch, we raise an error!
                raise ValueError(f"Time/normalized_doy dimension mismatch detected for computation of severity for the {year} year!")
            
            if show_debug and not shown_once:
                shown_once = True
                
                
            ## SAVING
            print(f"Starting Save of {baseline_name} {year} Severity Dataset...\n")

            # Setting up the destination filepath
            filename = f"{folder_name}_severity_{id_prefix}SST_PC{percentile}th_{year}_smoothWidth{smoothWidth}_{baseline_name}.zarr"
            sev_filepath = f'{my_root_directory}/{dataset_id}/Severity_{percentile}th/{baseline_name}/{filename}' # NOTE: POTENTIAL DIRECTORY TWEAKING HERE
            print("File Path: ", sev_filepath, '\n')

            # Last-minute chunking to not hit the maximum buffer size for chunks
            # (73 and 61 divide 365 and 366 evenly, giving equally sized time chunks)
            if len(severity_final.time.values) == 365:
                severity_final = severity_final.chunk({'time':73})
            elif len(severity_final.time.values) == 366:
                severity_final = severity_final.chunk({'time':61})
            print(f"{year} Severity:\n", severity_final, '\n')

            if show_debug:
                raise ValueError("You have reached the end of the debug. To begin saving the data, set show_debug to False!")
            else:
                with ProgressBar():
                    severity_final.to_zarr(sev_filepath, mode='w')
                
                # Freeing up memory
                severity_final.close()
                del severity_final
                gc.collect()
                
                print("Done saving! Moving along!", '\n')
                print("-----------------------------------------------------------------------------------------------------------\n")
    finally:
        # Always tell the memory monitor thread to stop, whether the loop finished or failed
        stop_monitoring = True

    print("Finished saving data for all years!")
    print("--------------------o-------------------------------o------------------------------o----------------------------", '\n')  


# Run the severity calculation for 1982-2024 (custom year range) with the script-wide settings.
# The 'quantile' variable is dropped from the thresholds before they are passed in.
calculate_severity(
    obs_data=Full_obs,
    thresh_data=Full_thresh.drop_vars('quantile'),
    clim_data=Full_means,
    baseline_name=baseline_name_arg,
    folder_name=folder_name_arg,
    custom_id=custom_id_choice,
    best_chunks=optimal_chunks,
    smoothWidth=31,
    percentile=current_percentile,
    minDaysMhwDuration=5,
    maxGapDaysInMhw=2,
    minutes_per_mem_update=5,
    show_debug=False,
    custom_years=True,
    start_yr=1982,
    end_yr=2024,
    mhw_test=False,
)

# Make sure the memory monitor is switched off once the call has returned
stop_monitoring = True

In [None]:
## Create an animation for MHW severity or severity (regardless of MHWs)

## Function to create an animation of one year's stored severity maps (severity for all days, or MHW-only severity when mhw_only=True).
def check_severities_with_an_animation(baseline_name_arg, folder_name_arg,
                                       year, mhw_only = False,
                                       custom_output_filename=None,
                                       percentile=None):
    """Render every stored day of one year's severity dataset as an animated GIF.

    The yearly Zarr store is located from the script-wide `my_root_directory` and
    `dataset_id` settings together with the arguments below, each time step is drawn
    on a Mercator map, and the frames are written with Pillow at 2 frames per second.

    Parameters
    ----------
    baseline_name_arg : str
        Baseline identifier embedded in the store name (e.g. "Baseline9322").
    folder_name_arg : str
        Prefix of the store name.
    year : int
        Year embedded in the store name.
    mhw_only : bool, default False
        If True, animate the "severity_mhw" variable; otherwise animate "severity".
    custom_output_filename : str or None, default None
        Output GIF path. When None, a name is built from the folder name, the
        variable name, the baseline name, and the year.
    percentile : int or None, default None
        Percentile embedded in the store path. When None, the script-wide
        `current_percentile` is used.

    Returns
    -------
    matplotlib.animation.FuncAnimation
        The animation object (its figure has already been closed).

    Raises
    ------
    ValueError
        If the store cannot be opened, does not contain the requested variable,
        or has no time steps.
    """
    ## First, pick the variable to read and the label used in the frame titles
    if mhw_only:
        data_var_to_access = "severity_mhw"
        data_type = "MHW Severity"
    else:
        data_var_to_access = "severity"
        data_type = "Severity"

    # Without an explicit percentile the path would contain "Noneth"; use the script-wide value
    if percentile is None:
        percentile = current_percentile

    # NOTE: POTENTIAL DIRECTORY TWEAKING HERE
    path = f'{my_root_directory}/{dataset_id}/Severity_{percentile}th/{folder_name_arg}_severity_SST_PC{percentile}th_{year}_{baseline_name_arg}.zarr'  

    ## Open the store; keep the underlying cause attached so real I/O errors stay visible
    try:
        store = xr.open_zarr(path)
    except Exception as err:
        raise ValueError(f"No file found at: {path}.") from err

    fig = None
    try:
        # A store that opens but lacks the variable is a different problem than a missing file
        if data_var_to_access not in store.data_vars:
            raise ValueError(f"Variable '{data_var_to_access}' not found in: {path}.")
        ds = store[data_var_to_access]

        available_times = pd.to_datetime(ds['time'].values)
        chosen_times = len(available_times)
        if chosen_times == 0:
            raise ValueError(f"No time steps found in: {path}.")

        # Grid coordinates shared by every frame
        lon = ds.lon.values
        lat = ds.lat.values

        # Set up the plot using the values in the first day
        setup_data = ds.isel(time=0).values

        ## Initialize the plot
        fig, ax = plt.subplots(figsize=(14, 6),
                               subplot_kw={'projection': ccrs.Mercator()})

        # Fixed colour limits so that every frame shares the same scale
        pcm = ax.pcolormesh(
            lon, lat, setup_data,
            cmap='RdYlBu_r',
            vmin=-1, vmax=5,
            transform=ccrs.PlateCarree(),
        )
        plt.colorbar(pcm, ax=ax, label='Temperature Severity')  # Adjust label as needed

        ax.set_extent([0, 360, -5, 90], crs=ccrs.PlateCarree())
        ax.add_feature(cfeature.LAND, color='lightgray')
        ax.add_feature(cfeature.COASTLINE, linewidth=0.8)
        ax.set_xlabel('Longitude')
        ax.set_ylabel('Latitude')

        title = ax.set_title('')
        title_base = '(Relative to 1993-2022)' if baseline_name_arg == 'Baseline9322' else f'({baseline_name_arg})'

        # The layout must be finalised before saving, otherwise it never reaches the GIF
        fig.tight_layout()

        ## Animation function: swap in the data and the date for frame i
        def animate(i):
            frame_data = ds.isel(time=i).values

            # Update the plot
            pcm.set_array(frame_data.ravel())

            # Update the title appropriately
            time_str = pd.Timestamp(available_times[i]).strftime('%Y-%m-%d')
            title.set_text(f'{data_type} for {time_str}\n{title_base}')

            return pcm, title

        ## Create the resulting animation
        type_message = "marine heatwave severity" if mhw_only else "severity"
        print(f"Began animation for the {type_message} in the {year} {folder_name_arg} datasets!")

        anim = animation.FuncAnimation(
            fig, animate,
            frames=chosen_times,
            interval=200,
            blit=False,
            repeat=True
        )

        writer = animation.PillowWriter(fps=2)

        if custom_output_filename is None:
            output_filename = f"{folder_name_arg}_{data_var_to_access}_{baseline_name_arg}_{year}.gif"
        else:
            output_filename = custom_output_filename
        print(f"Saving animation at: {output_filename}") 

        ## Save the resulting animation, reporting progress on a single overwritten line
        def print_frame_progress(current_frame, total_frames):
            print(f"\r → Date (Frame) Processed: {current_frame + 1}/{total_frames}", end='', flush=True)

        anim.save(output_filename, writer=writer, dpi=100,
                  progress_callback=print_frame_progress)
    finally:
        # Release the figure and the store whether or not the animation was produced
        if fig is not None:
            plt.close(fig)
        store.close()

    print(f"\nAnimation finished and saved!\n")

    return anim

# ----------------------------------------------------------------------------------------------------------------------------

# Animate both the plain severity and the MHW-only severity for three spot-check years.
# The order matches the original sequence of calls: for each year, severity first and
# MHW-only severity second. Looping also keeps the returned animation object from being
# echoed as the cell's output.
for check_year in (2024, 1993, 1982):
    for check_mhw_only in (False, True):
        check_severities_with_an_animation("Baseline9322", "Full", year=check_year,
                                           mhw_only=check_mhw_only,
                                           custom_output_filename=None, percentile=90)

MHW Detection Validation (coming soon)