# MHW statistics

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import sys
sys.path.append('/home/b/b382616/notebooks_home/MHW/ocetrac-dask')
import ocetrac_dask
import os
import imageio
from joblib import Parallel, delayed
import matplotlib.dates as mdates
import intake
import dask

# Visualization packages
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from matplotlib.animation import FFMpegWriter
from matplotlib import animation, rc
from IPython.display import HTML


import matplotlib.colors as mcolors
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter

from tempfile import TemporaryDirectory
from getpass import getuser
from pathlib import Path
from dask.distributed import Client, LocalCluster



import itertools
import gc
import subprocess
import re

import warnings
# Silence every warning for the rest of the session (keeps notebook output clean).
# NOTE(review): this also hides FutureWarning/DeprecationWarning from xarray/dask,
# e.g. notices about deprecated keyword arguments used later in this notebook.
warnings.filterwarnings('ignore')


In [None]:
# Start a local Dask cluster (32 workers x 4 threads) for the lazy reductions below.
cluster = LocalCluster(n_workers=32, threads_per_worker=4)
client = Client(cluster)

# My dashboard hack: print the "<node>:<port>" pair needed to SSH-forward the
# Dask dashboard served on this machine.
import socket
from urllib.parse import urlsplit

# Same value the `hostname` command prints, without spawning a subprocess.
remote_node = socket.gethostname().split('.')[0]
# urlsplit(...).port yields None (printed as "None") when the link carries no
# explicit port, instead of the AttributeError a failed re.search(...).group(1) raises.
port = str(urlsplit(client.dashboard_link).port)
print(f"Forward with Port = {remote_node}:{port}")

# A bare expression is only rendered when it is the last line of a notebook cell.
client

In [3]:
# Root of this project's scratch storage; the tracked-label Zarr store is read from here.
scratch_dir = Path('/scratch/b/b382616/mhws')

In [4]:
# Dask chunking for opening the labels: 100 time steps per chunk, and -1 (one chunk
# spanning the whole dimension) for lat/lon, which later act as apply_ufunc core dims.
chunk_size = dict(time=100, lat=-1, lon=-1)

In [None]:
# Open only the `labels` variable (lazily, dask-chunked) and cast it straight to int32.
# NOTE(review): if the stored labels are float with NaN as background, the int32 cast
# turns NaN into an arbitrary integer — confirm the store's dtype / fill value.
labels_path = str(scratch_dir / '02_track_dask_newmaskk.zarr')
mhw_labels = xr.open_zarr(labels_path, chunks=chunk_size)['labels'].astype(np.int32)
mhw_labels

In [None]:
# Don't use unique! Enumerate the candidate ids as the full range 0..max(label);
# this only needs one lazy max() reduction. Ids absent from the data are kept —
# comparing the labels against them simply yields all-False masks.
n_labels = int(mhw_labels.max().compute()) + 1

# Wrap the ids in a DataArray along a new 'event' dimension so they broadcast
# against mhw_labels by dimension name.
event_ids = xr.DataArray(np.arange(n_labels, dtype=np.int32), dims='event').chunk({'event': 1000})
event_ids

# Formation (first appearance of each event)

In [None]:
## Option 1: This will work because we reduce the 3D matrix ASAP
#    Still not great because we have an inefficient loop and store data in a communal list --> Not parallel

formation_events = []
for i in event_ids:
    # Lazy boolean mask of event i over (time, lat, lon).
    is_event = mhw_labels == i
    # Per time step: True if event i covers at least one grid cell.
    # NB: a .sum() here would give the event's area, and argmax would then return
    # the time of *peak* extent rather than the first appearance (formation).
    binary_event = is_event.any(dim=['lat', 'lon'])
    # argmax on booleans returns the index of the first True, i.e. the formation step
    # (0 if the id never occurs, in which case the mask selected below is all-False).
    start_index = binary_event.argmax().compute().values
    # Keep only event i's footprint at formation (bool), matching Options 2 and 3,
    # rather than the full label field at that time step.
    formation_events.append(is_event.isel(time=start_index))

combined_formation = xr.concat(formation_events, dim='event')

In [None]:
## Option 2/Better:
#   No Python loop; everything stays lazy.
#   Event slices are never gathered into one in-memory list.
#   The comparison yields booleans: 1 byte/element vs 4 for int32 and 8 for float64
#   (so ~8x smaller than float64, not 64x).
#   Chunking in lat/lon/time is preserved.
# .... Drawback: broadcasting against event_ids expands to a 4D
#      (time, lat, lon, event) array, so the task graph needs many chunks.

# Lazy 4D mask: True where the label at (time, lat, lon) equals the event id.
event_masks = mhw_labels == event_ids

# Which events are present at each time step -> dims (time, event).
binary_events = event_masks.any(dim=['lat', 'lon'])

# First time step of each event (argmax of a boolean returns the first True).
# Only this small (event,) array is actually computed.
start_indexes = binary_events.argmax(dim='time').compute().values

# Re-wrap as a DataArray along 'event' so that .isel below selects pointwise:
# one time step per event.
start_indexes_da = xr.DataArray(start_indexes, dims='event', coords={'event': event_ids})

# One (lat, lon) footprint per event, taken at its formation time.
# Still lazy: nothing beyond start_indexes has been computed yet.
formation_events = event_masks.isel(time=start_indexes_da)


In [7]:
## Option 3/Best:
#  Avoid even having to build the 4D binary_events array / event-dimension expansion.

def clever_binary_events(mhw_labels_chunk, max_ids):
    """Flag which event ids occur in a label field.

    Used below through ``xr.apply_ufunc(vectorize=True)`` with (lat, lon) as core
    dimensions, i.e. it receives one 2D label field per call.

    Parameters
    ----------
    mhw_labels_chunk : array-like
        Label field. Negative values (and NaN, for float input) mean "no event"
        and are ignored. Integer-valued float labels are accepted.
    max_ids : int
        Length of the output, i.e. the number of event ids (max label + 1).

    Returns
    -------
    numpy.ndarray of bool, shape (max_ids,)
        ``out[k]`` is True iff label ``k`` appears in ``mhw_labels_chunk``.

    Raises
    ------
    IndexError
        If a label is >= ``max_ids``.
    """
    labels = np.asarray(mhw_labels_chunk)
    # Comparisons with NaN are False, so NaN background is dropped here as well.
    present = labels[labels >= 0]
    binary_events_chunk = np.zeros(max_ids, dtype=bool)
    # Fancy-index assignment tolerates repeated indices, so no np.unique (a full sort
    # of the slice) is needed; casting to intp also lets float labels be used as indices.
    binary_events_chunk[present.astype(np.intp, copy=False)] = True
    return binary_events_chunk

# Lazily flag, for every time step, which event ids are present: clever_binary_events
# runs on each (lat, lon) slice, giving dims (time, event). lat/lon are core dims,
# so each must sit in a single dask chunk (see chunk_size).
binary_events = xr.apply_ufunc(
    clever_binary_events,
    mhw_labels,
    event_ids.shape[0],
    input_core_dims=[['lat', 'lon'], []],
    output_core_dims=[['event']],
    vectorize=True,
    dask='parallelized',
    # Passing output_sizes as a direct argument is deprecated (its FutureWarning is
    # hidden by the global warnings filter); it belongs in dask_gufunc_kwargs.
    dask_gufunc_kwargs={'output_sizes': {'event': event_ids.shape[0]}},
    output_dtypes=[bool],
)

# First time step of each event (argmax of a boolean returns the first True; 0 for
# ids that never occur, whose masks below are then all-False).
start_indexes = binary_events.argmax(dim='time').compute().values

# Make another Data Array so that we'll keep event when we extract time
start_indexes_da = xr.DataArray(start_indexes, dims='event', coords={'event': event_ids})

# Each event's footprint at its formation time. The pointwise isel gives one label
# field per event, dims (event, lat, lon); `== event_ids` then broadcasts by the
# 'event' dimension name. This replaces an apply_ufunc(vectorize=True) lambda that
# looped over events in Python. Result is still lazy.
formation_events = mhw_labels.isel(time=start_indexes_da) == event_ids


In [None]:
# Formation heat map: for each grid cell, the number of events whose formation
# footprint covers it (boolean masks summed over 'event').
formation_heatmap = formation_events.sum('event')

# Plotting triggers the dask computation of the lazy heat map.
formation_heatmap.plot()