In [None]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import dask.array as dsa
import matplotlib.pyplot as plt
import numpy as np
import s3fs
import scipy.ndimage
import xarray as xr
from matplotlib.colors import ListedColormap
from skimage.measure import label as label_np
from skimage.measure import regionprops

In [None]:
# NOAA OISST v2.1 (AVHRR) consolidated Zarr store on the Open Storage Network.
# The store is read anonymously (anon=True), so no credentials are needed.
endpoint_url = 'https://ncsa.osn.xsede.org'
fs_osn = s3fs.S3FileSystem(anon=True, client_kwargs={'endpoint_url': endpoint_url})

path = "Pangeo/pangeo-forge/noaa_oisst/v2.1-avhrr.zarr"
ds = xr.open_zarr(fs_osn.get_mapper(path), consolidated=True)
# Last expression: Jupyter renders the rich (HTML) Dataset repr instead of print()'s plain text.
ds

In [None]:
# Define a mask: 1 where the first SST field has data, 0 where it is missing (NaN).
# NaNs are tagged with the sentinel -999, every other cell becomes 1,
# and the sentinel cells are then turned into 0.
mask = (
    ds.sst.isel(time=0, zlev=0)
    .fillna(-999)
    .pipe(lambda field: field.where(field == -999, other=1))
    .pipe(lambda field: field.where(field == 1, other=0))
)

In [None]:
# Map of the mask built from the first SST field (1 = data, 0 = NaN).
plt.figure(figsize=(12,12))
ax = plt.axes(projection=ccrs.PlateCarree())
mask.plot(ax=ax, transform=ccrs.PlateCarree(), vmin=0, vmax=1, cmap='Greys_r', extend='max', add_colorbar=False, label=False)
ax.set_title('')
# GeoAxes.background_patch was deprecated in cartopy 0.18 and is not available in
# current releases; GeoAxes.patch is its replacement.
ax.patch.set_visible(False)

In [None]:
# SST anomaly on a single day, with land drawn in white over the field.
sst_anom = ds.anom.sel(time='2014-03-01', zlev=0).squeeze()
plt.figure(figsize=(12,12))
ax = plt.axes(projection=ccrs.PlateCarree())
sst_anom.plot(ax=ax, transform=ccrs.PlateCarree(), vmin=-3, vmax=3, cmap='RdBu_r', extend='both', add_colorbar=False, label=False)
ax.set_title('')
ax.coastlines(resolution='110m', color='black', linewidth=1)
ax.add_feature(cfeature.LAND, facecolor='w')
# GeoAxes.background_patch was deprecated in cartopy 0.18 and is not available in
# current releases; GeoAxes.patch is its replacement.
ax.patch.set_visible(False)

In [None]:
# Same day, keeping only anomalies >= 0.5 ("hot" pixels); everything else becomes NaN.
sst_hot = ds.anom.sel(time='2014-03-01', zlev=0).squeeze()
sst_hot = sst_hot.where(sst_hot >= 0.5, other=np.nan)

plt.figure(figsize=(12,12))
ax = plt.axes(projection=ccrs.PlateCarree())
sst_hot.plot(ax=ax, transform=ccrs.PlateCarree(), vmin=-3, vmax=3, cmap='RdBu_r', extend='max', add_colorbar=False, label=False)
ax.set_title('')
ax.coastlines(resolution='110m', color='black', linewidth=1)
ax.add_feature(cfeature.LAND, facecolor='w')
# GeoAxes.background_patch was deprecated in cartopy 0.18 and is not available in
# current releases; GeoAxes.patch is its replacement.
ax.patch.set_visible(False)

In [None]:
# Convert to binary: hot pixels (sst_hot > 0) -> 1, everything else (NaN) -> 0.
bitmap_binary = sst_hot.where(sst_hot>0, drop=False, other=0)
bitmap_binary = bitmap_binary.where(bitmap_binary==0, drop=False, other=1)

plt.figure(figsize=(12,12))
ax = plt.axes(projection=ccrs.PlateCarree())
bitmap_binary.plot(ax=ax, transform=ccrs.PlateCarree(), vmin=0, vmax=1, cmap='Greys', extend='max', add_colorbar=False, label=False)
ax.set_title('')
ax.coastlines(color='k', linewidth=.5)
# GeoAxes.background_patch was deprecated in cartopy 0.18 and is not available in
# current releases; GeoAxes.patch is its replacement.
ax.patch.set_visible(False)


In [None]:
# Define a disk-shaped structuring element for the morphological closing/opening.
radius = 8
diameter = radius*2  # also the wrap-around padding width used by binary_open_close
x = np.arange(-radius, radius+1)
x, y = np.meshgrid(x, x)
r = x**2 + y**2
se = r < radius**2  # strict '<': only points strictly inside the circle

# Visualise the element. Tick positions and limits come from se.shape, so the grid
# covers every cell (0..n inclusive) and stays correct when `radius` changes.
n_rows, n_cols = se.shape
plt.rcParams.update({'font.size': 12})
fig, ax = plt.subplots(1, 1, figsize=(6,6))
ax.pcolormesh(se, cmap='Greys', alpha=0.5)
ax.set_xticks(np.arange(0, n_cols + 1, 1), minor=True)
ax.set_yticks(np.arange(0, n_rows + 1, 1), minor=True)
ax.grid(True, which='both', axis='both', linestyle='-', color='k')
ax.set_xlim(0, n_cols)
ax.set_ylim(0, n_rows)
ax.set_aspect('equal', adjustable='box');


In [None]:
def binary_open_close(bitmap_binary, structure=None, pad=None):
    """Morphologically close, then open, a binary map using wrap-around padding.

    The last two axes form the image plane and are padded with ``mode='wrap'`` so
    features touching one edge are treated as continuing on the opposite edge.
    Any leading axes (e.g. time) are treated as a stack of independent frames:
    the structuring element gets singleton leading axes, so frames never interact.

    NOTE(review): wrap padding is applied to the second-to-last axis too (presumably
    latitude), which joins the northern and southern edges; only longitude is
    physically periodic -- confirm this is intended.

    Parameters
    ----------
    bitmap_binary : array-like
        Binary input with at least 2 dimensions.
    structure : array-like of bool, optional
        2-D structuring element. Defaults to the notebook-level ``se``.
    pad : int, optional
        Padding width on each side of the last two axes. Defaults to the
        notebook-level ``diameter``.

    Returns
    -------
    numpy.ndarray of bool with the same shape as ``bitmap_binary``.

    Raises
    ------
    ValueError
        If ``bitmap_binary`` has fewer than 2 dimensions.
    """
    image = np.asarray(bitmap_binary)
    if image.ndim < 2:
        raise ValueError("binary_open_close needs an array with at least 2 dimensions")
    structure = np.asarray(se if structure is None else structure, dtype=bool)
    pad = diameter if pad is None else int(pad)

    n_lead = image.ndim - 2
    # Singleton leading axes: the element only acts within each 2-D frame.
    structure = structure.reshape((1,) * n_lead + structure.shape)
    pad_width = [(0, 0)] * n_lead + [(pad, pad), (pad, pad)]

    padded = np.pad(image, pad_width, mode='wrap')
    closed = scipy.ndimage.binary_closing(padded, structure, iterations=1)
    opened = scipy.ndimage.binary_opening(closed, structure, iterations=1)

    # Explicit end indices (not -pad) so pad == 0 does not produce an empty slice.
    ny, nx = image.shape[-2:]
    return opened[..., pad:pad + ny, pad:pad + nx]



# Hot-pixel binary field for every time step (same 0.5 threshold as the single-day example).
sst_hot = ds.anom.sel(zlev=0).squeeze()
sst_hot = sst_hot.where(sst_hot >= 0.5, other=np.nan)
bitmap_binary = sst_hot.where(sst_hot > 0, drop=False, other=0)
bitmap_binary = bitmap_binary.where(bitmap_binary == 0, drop=False, other=1)

if bitmap_binary.chunks:
    # dask='parallelized' requires every *core* dimension to sit in a single chunk;
    # time is a loop dimension here and can stay chunked.
    bitmap_binary = bitmap_binary.chunk({'lat': bitmap_binary.sizes['lat'],
                                         'lon': bitmap_binary.sizes['lon']})

# Only lat/lon are core dims: binary_open_close is applied to one (lat, lon) frame
# at a time, and vectorize=True loops it over time.
mo_binary = xr.apply_ufunc(binary_open_close, bitmap_binary,
                           input_core_dims=[['lat', 'lon']],
                           output_core_dims=[['lat', 'lon']],
                           output_dtypes=[bitmap_binary.dtype],
                           vectorize=True,
                           dask='parallelized')

In [None]:
def _apply_mask(binary_images, mask):
    binary_images_with_mask = binary_images.where(mask==1, drop=False, other=0)
    return binary_images_with_mask

# Zero out every cell outside `mask` (cells where the first SST field was NaN).
# Rebinding mo_binary is safe to re-run: masking an already-masked field changes nothing.
mo_binary = _apply_mask(mo_binary, mask)

In [None]:
# Binary field after closing/opening and masking, for the example day.
plt.figure(figsize=(12,12))
ax = plt.axes(projection=ccrs.PlateCarree())
mo_binary.sel(time='2014-03-01').plot(ax=ax, transform=ccrs.PlateCarree(), vmin=0, vmax=1, cmap='Greys', extend='max', add_colorbar=False, label=False)
ax.set_title('')
ax.coastlines(color='k', linewidth=.5)
# GeoAxes.background_patch was deprecated in cartopy 0.18 and is not available in
# current releases; GeoAxes.patch is its replacement.
ax.patch.set_visible(False)

In [None]:
# label
def _label_either(data, **kwargs):
    """Label connected regions with dask_image for dask arrays, else with skimage.

    Parameters
    ----------
    data : dask.array.Array or array-like
        Binary image(s) to label. Only a raw ``dask.array.Array`` takes the dask
        path; anything else (including an xarray DataArray) goes to
        ``skimage.measure.label``.
    **kwargs
        Forwarded to the labelling function. ``dask_image.ndmeasure.label`` has no
        ``background`` argument (zero is always background there), so on the dask
        path ``background`` of 0/None is dropped and any other value is rejected.

    Returns
    -------
    The label array. On the dask path, the label count that dask_image also
    returns is discarded.

    Raises
    ------
    ImportError
        If ``data`` is a dask array but dask_image is not installed.
    ValueError
        If a non-zero ``background`` is requested on the dask path.
    """
    if not isinstance(data, dsa.Array):
        return label_np(data, **kwargs)

    try:
        from dask_image.ndmeasure import label as label_dask
    except ImportError as err:
        raise ImportError(
            "Dask_image is required to use this function on Dask arrays. "
            "Either install dask_image or else call .load() on your data."
        ) from err

    dask_kwargs = dict(kwargs)
    background = dask_kwargs.pop('background', 0)
    if background not in (0, None):
        raise ValueError("dask_image labelling always uses 0 as background; "
                         f"background={background!r} is not supported for dask arrays")
    ids, _num = label_dask(data, **dask_kwargs)
    return ids
    
def get_labels(binary_images):
    """Return integer ids for the connected regions of ``binary_images`` (0 is background)."""
    return _label_either(binary_images, background=0)
    
blobs_labels = get_labels(mo_binary)

# Re-attach mo_binary's dims/coords, then turn the 0 background into NaN so only
# labelled regions remain.
labels = xr.DataArray(
    blobs_labels, dims=mo_binary.dims, coords=mo_binary.coords
).pipe(lambda da: da.where(da > 0, other=np.nan))

# The whole array (every time step) is labelled in one call, so ids are not
# repeated per time step and no per-step relabelling is needed.

In [None]:
# One random colour per label id so neighbouring regions are distinguishable.
# A fixed seed keeps the colours identical on every re-run.
maxl = int(np.nanmax(labels.values))
cm = ListedColormap(np.random.default_rng(0).random(size=(maxl, 3)).tolist())

plt.figure(figsize=(12,12))
ax = plt.axes(projection=ccrs.PlateCarree())
labels.sel(time='2014-03-01').plot(ax=ax, cmap=cm, transform=ccrs.PlateCarree(), add_colorbar=False, label=False)
ax.set_title('')
ax.coastlines(color='k', linewidth=.5)
# GeoAxes.background_patch was deprecated in cartopy 0.18 and is not available in
# current releases; GeoAxes.patch is its replacement.
ax.patch.set_visible(False)

In [None]:
# wrap labels
def _wrap(labels):
        ''' Impose periodic boundary and wrap labels'''
        
        first_column = labels[..., 0]
        last_column = labels[..., -1]

        unique_first = np.unique(first_column[first_column>0])

        # This loop iterates over the unique values in the first column, finds the location of those values in 
        # the first columnm and then uses that index to replace the values in the last column with the first column value
        for i in enumerate(unique_first):
            print(i)
            first = np.where(first_column == i[1])
            last = last_column[first[0]]#, first[1]] # this works only if based on the assumption that the data contains multiple times
            bad_labels = np.unique(last[last>0])
            replace = np.isin(labels, bad_labels)
            labels[replace] = i[1]
        
#         labels_wrapped = np.unique(labels, return_inverse=True)[1].reshape(labels.shape)

        # recalculate the total number of labels 
        N = np.max(labels)

        return labels, N

# Merge regions that cross the seam between the first and last column of the last axis.
# np.array(...) hands _wrap a plain NumPy copy, so the DataArray `labels` is left unchanged.
labels_wrapped, N = _wrap(np.array(labels))

In [None]:
# Put the wrapped label ids back on the same dims/coords as `labels`.
labels_xr = xr.DataArray(labels_wrapped, dims=labels.dims, coords=labels.coords)


In [None]:
# Wrapped labels for the same day as the unwrapped map. labels_xr keeps the time
# dimension, and DataArray.plot() falls back to a histogram for >2-D data, so a
# single day must be selected.
plt.figure(figsize=(12,12))
ax = plt.axes(projection=ccrs.PlateCarree())
labels_xr.sel(time='2014-03-01').plot(ax=ax, cmap=cm, transform=ccrs.PlateCarree(), add_colorbar=False, label=False)
ax.set_title('')
ax.coastlines(color='k', linewidth=.5)
# GeoAxes.background_patch was deprecated in cartopy 0.18 and is not available in
# current releases; GeoAxes.patch is its replacement.
ax.patch.set_visible(False)



In [None]:
min_size_quartile = 0.75  # keep objects whose area is at or above this percentile

# Area (pixel count) of each labelled object; for a stacked (time, lat, lon) array
# this counts pixels over every time step an object spans. NaN marks background in
# labels_xr: map it to 0 (skimage's ignored label) before the int cast, because
# casting NaN to an integer is undefined.
props = regionprops(labels_xr.fillna(0).astype('int').values)
labelprops = [p.label for p in props]
labelprops = xr.DataArray(labelprops, dims=['label'], coords={'label': labelprops})
area = xr.DataArray([p.area for p in props], dims=['label'], coords={'label': labelprops})  # Number of pixels of the region.
min_area = np.percentile(area, min_size_quartile*100)
print('minimum area: ', min_area)
keep_labels = labelprops.where(area >= min_area, drop=True)
keep_where = np.isin(labels_xr, keep_labels)
out_labels = xr.DataArray(np.where(~keep_where, np.nan, labels_xr), dims=mo_binary.dims, coords=mo_binary.coords)

# Convert to binary: kept labels (> 0) -> 1, background/removed (NaN) -> 0.
# (Comparing against == 0 alone would turn the NaN background into 1 as well.)
binary_labels = out_labels.where(out_labels > 0, drop=False, other=0)
binary_labels = binary_labels.where(binary_labels == 0, drop=False, other=1)

In [None]:
# Objects that survived the area filter, on the example day. out_labels keeps the
# time dimension, and DataArray.plot() would draw a histogram for >2-D data.
plt.figure(figsize=(12,12))
ax = plt.axes(projection=ccrs.PlateCarree())
out_labels.sel(time='2014-03-01').plot(ax=ax, cmap=cm, transform=ccrs.PlateCarree(), add_colorbar=False, label=False)
ax.set_title('')
ax.coastlines(color='k', linewidth=.5)
# GeoAxes.background_patch was deprecated in cartopy 0.18 and is not available in
# current releases; GeoAxes.patch is its replacement.
ax.patch.set_visible(False)
