In [None]:
# Standard library
import glob
import hashlib
import math
import os
import re
from datetime import date

# Third-party
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cftime
import cmcrameri.cm as cmc  # cmcrameri colormaps
import dask
import matplotlib as mpl
import matplotlib.cm as cmm  # matplotlib colormaps
import matplotlib.colors as mcolors
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import netCDF4 as nc
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import scipy
import seaborn as sns
import statsmodels.api as sm
import statsmodels.formula.api as smf
import xarray as xr
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
from cmap import Colormap
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1 import make_axes_locatable
from pypalettes import load_cmap
from scipy.optimize import curve_fit
from statsmodels.formula.api import ols
from xmhw.xmhw import threshold, detect

# Load marineHeatWaves definition module
import marineHeatWaves as mhw

### urchin data preparation

In [None]:
# Load the invertebrate survey table and keep only the Centrostephanus rodgersii
# records with survey_year > 2012 and longitude < 155.
# The whole preparation is one chain that builds new objects at every step, so
# no column is ever assigned into a filtered slice of another frame (which is
# what triggers pandas' chained-assignment warning) and the cell can be re-run.
df_urchin = (
    pd.read_csv('ep_m2_inverts_australia.csv', low_memory=False)
    .loc[lambda d: d['taxon'].isin(['Centrostephanus rodgersii'])]
    .assign(
        # survey_date is expected as ISO 'YYYY-MM-DD' text
        survey_date=lambda d: pd.to_datetime(d['survey_date'], format='%Y-%m-%d'),
        # evaluated after survey_date above, so .dt is available here
        survey_year=lambda d: d['survey_date'].dt.year,
    )
    .loc[lambda d: (d['survey_year'] > 2012) & (d['longitude'] < 155)]
    # drop=True: the old row labels are not needed and would otherwise be
    # kept as a stray 'index' column
    .reset_index(drop=True)
)
df_urchin.head()

In [None]:
# Derive the per-record value `number` = total / 50 and keep only the columns
# used by the later cells.
# NOTE(review): the divisor 50 is presumably the surveyed area (m²) behind each
# `total` count — confirm against the survey protocol.
keep_cols = [
    'survey_id', 'location', 'site_code', 'site_name',
    'latitude', 'longitude', 'survey_year', 'program',
    'block', 'taxon', 'size_class', 'total', 'number',
]
df_urchin = df_urchin.assign(number=df_urchin['total'] / 50)[keep_cols]
df_urchin

In [None]:
# Collapse the survey records step by step down to one row per site.
# Each step groups by a progressively shorter key list, aggregates the numeric
# columns only, and turns the group keys back into ordinary columns.
site_keys = ["location", "site_name", "latitude", "longitude"]
site_year_keys = site_keys + ['survey_year']
survey_keys = ["survey_id"] + site_year_keys
survey_taxon_keys = survey_keys + ["taxon"]
survey_block_keys = survey_taxon_keys + ["block"]

# 1) sum over size classes within each survey / site / taxon / block
df_urchin_size = (
    df_urchin.groupby(survey_block_keys).sum(numeric_only=True).reset_index()
)
# 2) average over blocks
df_urchin_block = (
    df_urchin_size.groupby(survey_taxon_keys).mean(numeric_only=True).reset_index()
)
# 3) sum over taxa (kept although the species occur in different places)
df_urchin_tax = (
    df_urchin_block.groupby(survey_keys).sum(numeric_only=True).reset_index()
)
# 4) average over the surveys of the same site and year
df_urchin_id = (
    df_urchin_tax.groupby(site_year_keys).mean(numeric_only=True).reset_index()
)
# 5) average over the retained survey years (the latest decade)
df_urchin_decade = (
    df_urchin_id.groupby(site_keys).mean(numeric_only=True).reset_index()
)
df_urchin_decade

### Urchin number histogram along latitude

In [None]:
# Split the latitude range from -45 to -25 into n_bins equal-width,
# left-closed bins and record which bin every site falls into.
n_bins = 20
latitude_bins = np.linspace(-45, -25, num=n_bins + 1)  # bin edges, ascending
binned_latitude, bins = pd.cut(
    df_urchin_decade['latitude'],
    bins=latitude_bins,
    right=False,
    retbins=True,
)
df_urchin_decade['bin'] = binned_latitude
# One row per bin, built from the same edges with the same closed side
all_bins = pd.DataFrame({'bin': pd.IntervalIndex.from_breaks(bins, closed='left')})

In [None]:
# urchin threshold for kelp collapse under three mhw conditions
# (used as the inner bin edges when sites are sorted into resilience categories)
# NOTE(review): this rebinds the name `threshold`, which the import cell also
# binds via `from xmhw.xmhw import threshold, detect`; once this cell has run,
# `threshold` refers to this list.
threshold = [1.7, 2.4, 3.1]  # 1.7, 2.95, 4.05

In [None]:
# Categorize sites by comparing their `number` value with the three thresholds.
# right=False makes every class left-closed — [t0, t1), [t1, t2), ... — which is
# the convention used by the legend text ("[0, 1.7)", "[1.7, 2.4)", ...) and by
# the BoundaryNorm colouring of the map; with the default right-closed bins a
# site sitting exactly on a threshold would land one class lower than its colour.
# A later cell assigns this column again with a different label set.
df_urchin_decade['resilience_category'] = pd.cut(
    df_urchin_decade['number'],
    bins=[-float('inf'), threshold[0], threshold[1], threshold[2], float('inf')],
    labels=['Strong', 'Moderate', 'Low', 'None'],
    right=False,
)


In [None]:
# Display the per-site table, which now also carries the 'bin' and
# 'resilience_category' columns added by the preceding cells.
df_urchin_decade 

### Urchin number histogram along latitude

In [None]:
# Split the latitude range from -45 to -25 into n_bins equal-width,
# left-closed bins and record which bin every site falls into.
# (This cell repeats the binning already performed in an earlier cell.)
n_bins = 20
latitude_bins = np.linspace(-45, -25, num=n_bins + 1)  # bin edges, ascending
binned_latitude, bins = pd.cut(
    df_urchin_decade['latitude'],
    bins=latitude_bins,
    right=False,
    retbins=True,
)
df_urchin_decade['bin'] = binned_latitude
# One row per bin, built from the same edges with the same closed side
all_bins = pd.DataFrame({'bin': pd.IntervalIndex.from_breaks(bins, closed='left')})

In [None]:
# urchin threshold for kelp collapse under three mhw conditions
# (used as the inner bin edges when sites are sorted into resilience categories)
# NOTE(review): this rebinds the name `threshold`, which the import cell also
# binds via `from xmhw.xmhw import threshold, detect`; once this cell has run,
# `threshold` refers to this list.
threshold = [1.7, 2.4, 3.1]  # 1.7, 2.95, 4.05

In [None]:
# Categorize sites by comparing their `number` value with the three thresholds.
# right=False makes every class left-closed — [t0, t1), [t1, t2), ... — which is
# the convention used by the legend text ("[0, 1.7)", "[1.7, 2.4)", ...) and by
# the BoundaryNorm colouring of the map; with the default right-closed bins a
# site sitting exactly on a threshold would land one class lower than its colour.
df_urchin_decade['resilience_category'] = pd.cut(
    df_urchin_decade['number'],
    bins=[-float('inf'), threshold[0], threshold[1], threshold[2], float('inf')],
    labels=['High', 'Moderate', 'Low', 'Minimal'],
    right=False,
)


In [None]:
# Count the sites per latitude bin and resilience category.
# observed=False keeps every bin/category combination, including those with
# zero sites, so the wide table has one column per category for every bin.
category_counts = (
    df_urchin_decade
    .groupby(['bin', 'resilience_category'], observed=False)
    .size()
    .unstack(fill_value=0)
    .reset_index()
)

# Largest value of every numeric column within each bin; only 'number' is used below.
df_number_binned = (
    df_urchin_decade
    .groupby('bin', observed=False)
    .max(numeric_only=True)
    .reset_index()
)

# Align with the full list of bins, then attach the per-bin maximum of 'number'.
# Both merges are left joins, so the rows of the count table are what is kept.
binned_df = category_counts.merge(all_bins, on='bin', how='left')
binned_df = binned_df.merge(df_number_binned[['bin', 'number']], on='bin', how='left')

In [None]:
# Label every bin by the midpoint of its interval, rounded to one decimal.
bin_midpoints = binned_df['bin'].apply(lambda interval: interval.mid)
binned_df['latitude'] = bin_midpoints.round(1)
binned_df

In [None]:
# Share of sites in each resilience category within every latitude bin
# (normalize="index" makes each row sum to 1), then re-keyed by the bin
# midpoints so that every bin has a row; rows without a match are filled with 0.
# NOTE(review): the crosstab rows are keyed by the interval bins while the
# reindex target holds midpoint floats — this relies on pandas matching each
# midpoint to the bin that contains it; confirm the rows are not all zero.
cross_tab_prop = (
    pd.crosstab(
        df_urchin_decade['bin'],
        df_urchin_decade['resilience_category'],
        normalize="index",
    )
    .reindex(binned_df['latitude'], fill_value=0)
)
cross_tab_prop

In [None]:
# Discrete 4-colour map: four evenly spaced samples of cmcrameri's `navia`
# between 10 % and 90 % of its range.
navia_samples = cmc.navia(np.linspace(0.1, 0.9, 4))
cmap_disc = mcolors.ListedColormap(navia_samples)
cmap_disc

In [None]:
# Two-panel summary figure:
#   left  - map of the per-site urchin density, coloured in four classes, with a
#           narrow latitudinal density outline attached to its right edge;
#   right - stacked horizontal bars with the proportion of sites per kelp
#           resilience category in every latitude bin.
fig = plt.figure(figsize=(10, 6))
proj = ccrs.PlateCarree()

# Create both axes explicitly. The list has to be built here: assigning into
# `axs[0]` / `axs[1]` without creating `axs` first fails on a fresh kernel.
axs = [
    fig.add_subplot(1, 2, 1, projection=proj),  # map panel (needs a projection)
    fig.add_subplot(1, 2, 2),                   # bar panel (plain axes)
]
map_ax, bar_ax = axs

# ---------------- map panel ----------------
# Sort ascending by density first; presumably so that high-density sites are
# drawn last and stay visible on top of overlapping markers.
df_urchin_decade = df_urchin_decade.sort_values(by=['number'])
levels = [0, 1.7, 2.4, 3.1, 5]  # class edges: the three thresholds plus outer edges
norm = mcolors.BoundaryNorm(levels, ncolors=cmap_disc.N, extend='neither')
scatter = map_ax.scatter(
    np.array(df_urchin_decade.longitude),
    np.array(df_urchin_decade.latitude),
    c=df_urchin_decade.number,
    cmap=cmap_disc,
    norm=norm,
    marker='o',
    s=45,
    transform=proj,
    alpha=0.9,
)

map_ax.gridlines(alpha=0.3)
map_ax.add_feature(cfeature.LAND, facecolor='lightgrey')

# Map extent
lon_range = [144, 158]
lat_range = [-45, -25]
map_ax.set_xlim(lon_range)
map_ax.set_ylim(lat_range)

# Tick labels such as "150.0°E" / "35.0°S"
lat_formatter = ticker.FuncFormatter(lambda x, pos: '{:}°{}'.format(abs(x), 'S' if x < 0 else 'N'))
lon_formatter = ticker.FuncFormatter(lambda x, pos: '{:}°{}'.format(abs(x), 'W' if x < 0 else 'E'))
map_ax.set_xticks(np.arange(145, 155, 5), crs=proj)
map_ax.set_xticklabels([lon_formatter(x) for x in map_ax.get_xticks()], fontsize=10)
map_ax.set_yticks(np.arange(lat_range[0], lat_range[1] + 1, 5), crs=proj)
map_ax.set_yticklabels([lat_formatter(y) for y in map_ax.get_yticks()], fontsize=10)

# ---------------- latitudinal outline attached to the map ----------------
divider = make_axes_locatable(map_ax)
# The negative pad makes the strip overlap the right edge of the map.
ax_plot = divider.append_axes("right", size="20%", pad=-0.7, axes_class=plt.Axes)

# NOTE(review): `y_edges` and `x_edges` are not defined in any cell shown before
# this one; they must exist in the kernel (density values on the x axis,
# latitudes on the y axis, judging by the axis limits and label below) —
# confirm where they are built so the notebook survives Restart & Run All.
ax_plot.plot(y_edges, x_edges, drawstyle='steps-mid', color='brown', linewidth=1.2)
ax_plot.set_ylim([-45, -25])
ax_plot.set_xlabel("Sea urchin\ndensity(/m\u00b2)", fontsize=11)
ax_plot.set_xlim([0, 10])
ax_plot.set_xticks([0, 5, 10])

# Axis furniture: ticks and label on top, only the left and top spines visible
ax_plot.xaxis.set_ticks_position('top')
ax_plot.xaxis.set_label_position('top')
ax_plot.spines['bottom'].set_visible(False)
ax_plot.spines['left'].set_visible(True)
ax_plot.spines['right'].set_visible(False)
ax_plot.set_yticks([])

# ---------------- stacked-bar panel ----------------
bar_ax.set_position([0.5, 0.11, 0.3, 0.77])
cross_tab_prop.plot(
    kind='barh',
    stacked=True,
    width=1,
    # Same colormap as the map markers and the legend swatches below, so the
    # bar colours agree with the legend that describes them.
    colormap=cmap_disc,
    edgecolor='lightgray',
    ax=bar_ax,
)

bar_ax.legend().remove()  # replaced by the shared legend built below
bar_ax.set_xlabel("Proportion of sites by kelp\nresilience to MHW", fontsize=12)
bar_ax.xaxis.set_ticks_position('top')
bar_ax.xaxis.set_label_position('top')
bar_ax.set_ylabel("")
bar_ax.set_yticks([])

# ---------------- shared legend ----------------
# One colour swatch per class; the density range is written above each swatch
# and the matching resilience level below it.
colors = [cmap_disc(i / 4) for i in range(4)]
top_labels = ["[0, 1.7)", "[1.7, 2.4)", "[2.4, 3.1)", '≥ 3.1']
bottom_labels = ["High", "Moderate", "Low", 'Minimal']

patches = [mpatches.Patch(color=c) for c in colors]

# Label-less legend that only provides the row of swatches
legend = plt.legend(handles=patches, loc="lower center", ncol=4, bbox_to_anchor=(2.5, -0.22),
                    handlelength=4.5, handleheight=1.5)

# Swatch captions, positioned in figure coordinates
for i, (top_txt, bottom_txt) in enumerate(zip(top_labels, bottom_labels)):
    x_pos = 0.4 + i * 0.104
    plt.text(x_pos, 0.01, top_txt, ha='center', va='bottom', fontsize=11, transform=fig.transFigure)
    plt.text(x_pos, -0.06, bottom_txt, ha='center', va='top', fontsize=11, transform=fig.transFigure)

# Row titles to the left of the swatches
plt.text(0.26, 0.035, 'Sea urchin density(/m\u00b2)', ha='center', va='top', fontsize=11,
         weight='bold', transform=fig.transFigure)
plt.text(0.265, -0.057, 'Kelp resilience levels', ha='center', va='top', fontsize=11,
         weight='bold', transform=fig.transFigure)