In [None]:
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import numpy as np
import os
import cmocean as cm
import waypoint_distance as wd
import pandas as pd
from pathlib import Path
from datetime import datetime
from matplotlib.dates import DateFormatter
import gsw
import matplotlib.dates as mdates
%matplotlib widget

# All input files live in one directory. Expand "~" once so that every
# dataset is opened from the same, consistently resolved location.
DATA_DIR = os.path.expanduser('~/Desktop/Summer 2025 Python')

cube = xr.open_dataset(os.path.join(DATA_DIR, 'calvert_cube.nc'))
topo = xr.open_dataset(os.path.join(DATA_DIR, 'british_columbia_3_msl_2013.nc'))
ds = xr.open_dataset(os.path.join(DATA_DIR, 'Hakai_calvert.nc'))

def plot_ts_from_station(ds, station_id, years=[2023, 2024, 2025],
                         depth_range=(100, 450), xlim=None, ylim=None, target_months=None):
    """
    Draw a T–S diagram for one station, one scatter series per distinct time.

    Parameters
    ----------
    ds : xarray.Dataset
        Must provide 'station', 'temperature', 'salinity', 'depth' and 'time'
        variables along a 'row' dimension.
    station_id
        Value compared with ``ds['station']`` to select the station.
    years : sequence of int
        Calendar years to keep (read only, never modified).
    depth_range : (min, max)
        Inclusive depth window; only samples inside it are plotted.
    xlim, ylim : (min, max) or None
        Axis limits. When None, the limits are the plotted data range
        padded by 0.1 on each side.
    target_months : sequence of int or None
        Optional calendar-month filter.

    Returns
    -------
    None
        The diagram is drawn on a new figure. If fewer than two samples
        survive the filtering, a message is printed and nothing is drawn.

    Notes
    -----
    Updates the global ``plt.rcParams`` font sizes as a side effect.
    """
    import itertools  # not imported at file level

    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 20,
        'figure.titlesize': 20})

    # Keep only this station's complete samples in the requested years/months.
    ds_station = ds.where(ds['station'] == station_id, drop=True)
    ds_station = ds_station.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time'])
    ds_station = ds_station.where(ds_station['time'].dt.year.isin(years), drop=True)
    if target_months is not None:
        ds_station = ds_station.where(ds_station['time'].dt.month.isin(target_months), drop=True)

    if ds_station.sizes['row'] < 2:
        print(f"No valid data for station {station_id} in years {years}.")
        return

    # One colour per distinct timestamp; the marker cycle advances only for
    # profiles that are actually drawn.
    times = np.unique(ds_station['time'].values)
    cmap = plt.colormaps.get_cmap('jet')
    colors = cmap(np.linspace(0, 1, len(times)))
    marker = itertools.cycle(['o', 'D', 's', 'X', 'P', '^', 'v', '<', '>', '*', 'h'])

    fig, ax = plt.subplots(figsize=(1.5 * 1.5 * 6.4, 1.5 * 4.8))
    ax.set_facecolor('lightgrey')

    # Background sigma-theta field from which the isopycnals are contoured.
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 200),
                                 np.linspace(5.3, 20, 200))
    sigma = gsw.sigma0(S_grid, T_grid)

    # (levels, colour, line width, draw inline labels?)
    isopycnal_styles = [
        (np.linspace(24, 27, 7), 'black', .51, True),
        ([25.6], 'white', 0.5, False),
        ([25.7], 'lime', 0.5, False),
        ([25.8], 'red', 0.5, False),
        ([25.9], 'blue', 0.5, False),
        ([26.0], 'black', 0.51, True),
        ([26.1], 'purple', 0.5, False),
        ([26.2], 'salmon', 0.5, False),
        ([26.3], 'yellow', 0.5, False),
        ([26.4], 'cyan', 0.5, False),
    ]
    for levels, color, lw, labelled in isopycnal_styles:
        cs = ax.contour(S_grid, T_grid, sigma, levels=levels,
                        colors=color, linewidths=lw, linestyles='--')
        if labelled:
            ax.clabel(cs, fmt='%1.2f', fontsize=9, inline=True)

    # Plot each profile and remember what was drawn for the automatic limits.
    all_S, all_T = [], []
    for i, t in enumerate(times):
        profile = ds_station.where(ds_station['time'] == t, drop=True)
        if profile.sizes['row'] < 2:
            continue

        depth = profile['depth'].values
        in_window = (depth >= depth_range[0]) & (depth <= depth_range[1])
        temp = profile['temperature'].values[in_window]
        sal = profile['salinity'].values[in_window]

        valid = ~np.isnan(temp) & ~np.isnan(sal)
        if np.any(valid):
            ax.scatter(sal[valid], temp[valid],
                       color=colors[i],
                       s=90,
                       marker=next(marker),
                       linewidth=0.2,
                       label=pd.Timestamp(t).strftime('%d %b %Y'))

            all_S.extend(sal[valid])
            all_T.extend(temp[valid])

    # Automatic axis limits: data range padded by 0.1.
    if all_S and all_T:
        all_S = np.array(all_S)
        all_T = np.array(all_T)
        if xlim is None:
            xlim = (np.nanmin(all_S) - 0.1, np.nanmax(all_S) + 0.1)
        if ylim is None:
            ylim = (np.nanmin(all_T) - 0.1, np.nanmax(all_T) + 0.1)

    ax.set_xlabel("Salinity (psu)")
    ax.set_ylabel(r"$\theta$ (°C)")
    ax.set_title(f"T–S Diagram at {station_id}")
    # Only build a legend when a profile was drawn; an empty legend just
    # triggers a Matplotlib warning.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(fontsize=9, loc='upper right')
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    plt.tight_layout()

In [None]:
def plot_ts_shelf_basin(cube, basin_range=(5000, 15000), sill_range=(23500, 30000), shelf_range=(65000, 70000),
                        target_years=[2024], target_months=[5], which_date='first', xlim=None, ylim=None):
    """
    T–S diagram contrasting basin vs sill vs shelf water.

    For each of the three along-track regions, the first or second transect
    (in dataset order) falling in the selected years/months is scattered in
    its own colour over a background of labelled isopycnals.

    Parameters
    ----------
    cube : xarray.Dataset
        Needs an 'along' coordinate, a 'transect' coordinate and
        'temperature' / 'salinity' variables. It is loaded into memory.
        Each transect label is assumed to start with an 8-character date
        that pandas can parse (e.g. YYYYMMDD) — TODO confirm.
    basin_range, sill_range, shelf_range : (start, end)
        Along-track windows passed to ``cube.sel(along=slice(start, end))``.
        The legend reports them divided by 1000 as km, i.e. it assumes
        'along' is in metres.
    target_years, target_months : sequence of int
        Calendar years / months a transect must fall in (read only).
    which_date : {'first', 'second'}
        Which matching transect to plot; 'second' falls back to the first
        when only one transect matches.
    xlim, ylim : (min, max) or None
        Optional axis limits.

    Notes
    -----
    Updates the global ``plt.rcParams`` font sizes as a side effect.
    A later cell in this notebook defines another function with the same
    name and a different signature; once that cell runs it replaces this one.
    """
    from scipy.optimize import root_scalar  # not imported at file level

    assert which_date in ['first', 'second'], "which_date must be 'first' or 'second'"

    plt.rcParams.update({
        "xtick.labelsize": 30,
        "ytick.labelsize": 30,
        "axes.titlesize": 30,
        "axes.labelsize": 30,
        "legend.fontsize": 20,
        "figure.titlesize": 40})

    cube = cube.load()
    fig, ax = plt.subplots(figsize=(1.5 * 1.5 * 6.4, 1.5 * 4.8))
    ax.set_facecolor('grey')

    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 300), np.linspace(5.3, 20, 300))
    sigma = gsw.sigma0(S_grid, T_grid)

    # Each isopycnal label is pinned where the contour crosses T = T_label:
    # solve sigma0(S, T_label) = level for S inside the plotted salinity range.
    T_label = 6.0
    isopycnal_levels = [23.0, 23.5, 24.0, 24.5, 25.0, 25.5, 25.6, 25.7, 25.8, 25.9, 26.0, 26.1, 26.2, 26.3, 26.4, 26.5]
    manual_labels = {}
    for iso in isopycnal_levels:
        sol = root_scalar(lambda S, iso=iso: gsw.sigma0(S, T_label) - iso,
                          bracket=[28, 35], method='brentq')
        manual_labels[iso] = (sol.root if sol.converged else np.nan, T_label)

    for levels, color, lw in [
        (np.linspace(23, 27, 9), 'black', 0.5),
        ([25.6], 'white', 0.5),
        ([25.7], 'lime', 0.5),
        ([25.8], 'red', 0.5),
        ([25.9], 'blue', 0.5),
        ([26.0], 'black', 0.5),
        ([26.1], 'purple', 0.5),
        ([26.2], 'salmon', 0.5),
        ([26.3], 'yellow', 0.5),
        ([26.4], 'cyan', 0.5)
    ]:
        cs = ax.contour(S_grid, T_grid, sigma, levels=levels,
                        colors=color, linewidths=lw, linestyles='--')
        for level in levels:
            if level in manual_labels:
                ax.clabel(cs, fmt='%1.2f', fontsize=10, inline=True,
                          manual=[manual_labels[level]])

    def _km_span(bounds):
        # Legend text for a (start, end) window, converted from m to km.
        return f"{bounds[0] / 1000:g}–{bounds[1] / 1000:g} km"

    # Legend labels are built from the actual ranges so they cannot drift
    # out of sync with the arguments.
    regions = [
        (basin_range, f"Basin water ({_km_span(basin_range)})", 'cyan'),
        (sill_range, f"Sill water ({_km_span(sill_range)})", 'red'),
        (shelf_range, f"QCS water ({_km_span(shelf_range)})", 'salmon'),
    ]
    for (start, end), label, color in regions:
        subset = cube.sel(along=slice(start, end))
        times = pd.to_datetime([str(t)[:8] for t in subset.transect.values])
        # enumerate() yields ascending indices, so the list is already sorted.
        selected_idxs = [i for i, t in enumerate(times)
                         if t.year in target_years and t.month in target_months]

        if not selected_idxs:
            print(f"No transect in years {target_years}, months {target_months} for {label}.")
            continue

        use_second = which_date == 'second' and len(selected_idxs) > 1
        idx = selected_idxs[1] if use_second else selected_idxs[0]
        small = subset.isel(transect=idx)[['temperature', 'salinity']]

        T = small['temperature'].values.ravel()
        S = small['salinity'].values.ravel()
        mask = ~np.isnan(T) & ~np.isnan(S)
        ax.scatter(S[mask], T[mask], s=20, color=color, label=label, marker='o')

    ax.set_xlabel("Salinity (psu)")
    ax.set_ylabel(r"$\theta$ (°C)")
    ax.set_title(f"T–S Diagram: Basin vs Sill vs Shelf Water ({which_date.capitalize()} transect)")
    # Skip the legend when no region produced a scatter series.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()
    if xlim: ax.set_xlim(xlim)
    if ylim: ax.set_ylim(ylim)
    plt.tight_layout()

# plot_ts_shelf_basin(
#     cube,
#     basin_range=(5000, 15000),
#     sill_range=(23500, 30000),
#     shelf_range=(65000, 75000),
#     target_years=[2024],
#     target_months=[5],
#     which_date='first',
#     xlim=(30, 34),
#     ylim=(5, 12)
# )

# plot_ts_shelf_basin(
#     cube,
#     basin_range=(5000, 15000),
#     sill_range=(23500, 30000),
#     shelf_range=(65000, 75000),
#     target_years=[2024],
#     target_months=[5],
#     which_date='second',
#     xlim=(30, 34),
#     ylim=(5, 12)
# )

# Speculative idea: this comparison might lead to interesting results that indicate mixing.

In [None]:
def plot_ts_shelf_basin(cube, region_range=(0, 30000),
                        target_years=[2024], target_months=[5], which_date='first',
                        xlim=None, ylim=None):
    """
    T–S diagram with individual profiles colored by along-track distance using a colormap.

    Parameters
    ----------
    cube : xarray.Dataset
        Needs an 'along' coordinate, a 'transect' coordinate and
        'temperature' / 'salinity' variables. It is loaded into memory.
        Each transect label is assumed to start with an 8-character date
        that pandas can parse (e.g. YYYYMMDD) — TODO confirm.
    region_range : (start, end)
        Along-track window passed to ``cube.sel(along=slice(start, end))``;
        it also sets the colour scale limits.
    target_years, target_months : sequence of int
        Calendar years / months a transect must fall in (read only).
    which_date : {'first', 'second'}
        Which matching transect to plot; 'second' falls back to the first
        when only one transect matches.
    xlim, ylim : (min, max) or None
        Optional axis limits.

    Notes
    -----
    Updates the global ``plt.rcParams`` font sizes as a side effect.
    This definition reuses the name of a function defined in an earlier
    cell (which takes basin/sill/shelf ranges) and replaces it once run.
    """
    from scipy.optimize import root_scalar  # not imported at file level
    from matplotlib.colors import Normalize

    assert which_date in ['first', 'second'], "which_date must be 'first' or 'second'"

    plt.rcParams.update({
        "xtick.labelsize": 30,
        "ytick.labelsize": 30,
        "axes.titlesize": 30,
        "axes.labelsize": 30,
        "legend.fontsize": 20,
        "figure.titlesize": 40})

    cube = cube.load()
    fig, ax = plt.subplots(figsize=(1.5 * 1.5 * 6.4, 1.5 * 4.8))
    ax.set_facecolor('grey')

    # --- Isopycnals ---
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 300), np.linspace(5.3, 20, 300))
    sigma = gsw.sigma0(S_grid, T_grid)
    # Each label is pinned where its contour crosses T = T_label.
    T_label = 6.0
    isopycnal_levels = [23.0, 23.5, 24.0, 24.5, 25.0, 25.5, 25.6, 25.7, 25.8, 25.9, 26.0, 26.1, 26.2, 26.3, 26.4, 26.5]
    manual_labels = {}
    for iso in isopycnal_levels:
        sol = root_scalar(lambda S, iso=iso: gsw.sigma0(S, T_label) - iso,
                          bracket=[28, 35], method='brentq')
        manual_labels[iso] = (sol.root if sol.converged else np.nan, T_label)

    for levels, color, lw in [
        (np.linspace(23, 27, 9), 'black', 0.5),
        ([25.6], 'white', 0.5),
        ([25.7], 'lime', 0.5),
        ([25.8], 'red', 0.5),
        ([25.9], 'blue', 0.5),
        ([26.0], 'black', 0.5),
        ([26.1], 'purple', 0.5),
        ([26.2], 'salmon', 0.5),
        ([26.3], 'yellow', 0.5),
        ([26.4], 'cyan', 0.5)
    ]:
        cs = ax.contour(S_grid, T_grid, sigma, levels=levels,
                        colors=color, linewidths=lw, linestyles='--')
        for level in levels:
            if level in manual_labels:
                ax.clabel(cs, fmt='%1.2f', fontsize=10, inline=True,
                          manual=[manual_labels[level]])

    # --- Colour by along-track distance ---
    subset = cube.sel(along=slice(*region_range))
    times = pd.to_datetime([str(t)[:8] for t in subset.transect.values])
    # enumerate() yields ascending indices, so the list is already sorted.
    selected_idxs = [i for i, t in enumerate(times)
                     if t.year in target_years and t.month in target_months]

    if selected_idxs:
        use_second = which_date == 'second' and len(selected_idxs) > 1
        idx = selected_idxs[1] if use_second else selected_idxs[0]
        small = subset.isel(transect=idx)[['temperature', 'salinity']]

        along_vals = small['along'].values
        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; use the
        # colormap registry, as the rest of this notebook already does.
        cmap = plt.colormaps['jet']
        norm = Normalize(vmin=region_range[0], vmax=region_range[1])

        # One scatter call per along-track column so each column gets the
        # colour of its own distance.
        for a_idx, along in enumerate(along_vals):
            T = small['temperature'].isel(along=a_idx).values
            S = small['salinity'].isel(along=a_idx).values
            mask = ~np.isnan(T) & ~np.isnan(S)
            if np.any(mask):
                ax.scatter(S[mask], T[mask], s=20, color=cmap(norm(along)), marker='o')

        # The scatters use explicit colours, so the colorbar needs its own mappable.
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        fig.colorbar(sm, ax=ax, orientation='vertical', label='Along-transect distance (m)')
    else:
        print(f"No transect found in years {target_years}, months {target_months}.")

    # --- Axis and layout ---
    ax.set_xlabel("Salinity (psu)")
    ax.set_ylabel(r"$\theta$ (°C)")
    ax.set_title(f"T–S Diagram ({which_date.capitalize()} transect)")
    if xlim: ax.set_xlim(xlim)
    if ylim: ax.set_ylim(ylim)
    plt.tight_layout()

In [None]:
def plot_ts_shelf_basin_by_transect(
    cube,
    region_range=(0, 30000),
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    xlim=(30.5, 34),
    ylim=(5, 16),
    ncols=4,
):
    """
    Grid of T–S diagrams, one panel per matching transect, with points
    coloured by along-track distance.

    Parameters
    ----------
    cube : xarray.Dataset
        Needs an 'along' coordinate, a 'transect' coordinate and
        'temperature' / 'salinity' variables. It is loaded into memory.
        Each transect label is assumed to start with an 8-character date
        that pandas can parse (e.g. YYYYMMDD) — TODO confirm.
    region_range : (start, end)
        Along-track window passed to ``cube.sel(along=slice(start, end))``;
        it also sets the colour scale limits.
    target_years, target_months : sequence of int
        Calendar years / months a transect must fall in (read only).
    xlim, ylim : (min, max)
        Axis limits applied to every panel.
    ncols : int
        Number of panel columns; rows are added as needed.

    Notes
    -----
    Updates the global ``plt.rcParams`` font sizes as a side effect.
    Prints a message and returns without drawing when no transect matches.
    A later cell defines another function with this same name (adding
    insets); once that cell runs it replaces this one.
    """
    from scipy.optimize import root_scalar  # not imported at file level
    from matplotlib.colors import Normalize

    plt.rcParams.update({
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "axes.titlesize": 14,
        "axes.labelsize": 12,
        "legend.fontsize": 10,
        "figure.titlesize": 16})

    cube = cube.load()
    subset = cube.sel(along=slice(*region_range))
    times = pd.to_datetime([str(t)[:8] for t in subset.transect.values])
    selected_idxs = [i for i, t in enumerate(times)
                     if t.year in target_years and t.month in target_months]

    n_panels = len(selected_idxs)
    if n_panels == 0:
        # plt.subplots rejects nrows=0, so bail out with a clear message.
        print("No transects match the requested years/months.")
        return

    nrows = int(np.ceil(n_panels / ncols))
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                             figsize=(4.5 * ncols, 4 * nrows), squeeze=False)
    axes = axes.ravel()

    # Isopycnal grid
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 300), np.linspace(5.3, 20, 300))
    sigma = gsw.sigma0(S_grid, T_grid)
    # Each label is pinned where its contour crosses T = T_label.
    T_label = 6.0
    isopycnal_levels = [23.0, 23.5, 24.0, 24.5, 25.0, 25.5, 25.6, 25.7, 25.8, 25.9,
                        26.0, 26.1, 26.2, 26.3, 26.4, 26.5]
    manual_labels = {}
    for iso in isopycnal_levels:
        sol = root_scalar(lambda S, iso=iso: gsw.sigma0(S, T_label) - iso,
                          bracket=[28, 35], method='brentq')
        manual_labels[iso] = (sol.root if sol.converged else np.nan, T_label)

    # (levels, colour, line width) — identical for every panel, so built once.
    isopycnal_styles = [
        (np.linspace(23, 27, 9), 'black', 0.5),
        ([25.6], 'white', 0.5),
        ([25.7], 'lime', 0.5),
        ([25.8], 'red', 0.5),
        ([25.9], 'blue', 0.5),
        ([26.0], 'black', 0.5),
        ([26.1], 'purple', 0.5),
        ([26.2], 'salmon', 0.5),
        ([26.3], 'yellow', 0.5),
        ([26.4], 'cyan', 0.5),
    ]

    cmap = plt.colormaps['jet']
    norm = Normalize(vmin=region_range[0], vmax=region_range[1])

    for i, idx in enumerate(selected_idxs):
        ax = axes[i]
        ax.set_facecolor('grey')

        # Isopycnals are drawn before the axis limits are set, as before.
        for levels, color, lw in isopycnal_styles:
            cs = ax.contour(S_grid, T_grid, sigma, levels=levels,
                            colors=color, linewidths=lw, linestyles='--')
            for level in levels:
                if level in manual_labels:
                    ax.clabel(cs, fmt='%1.2f', fontsize=7, inline=False,
                              manual=[manual_labels[level]])

        transect = subset.isel(transect=idx)
        temp = transect['temperature']
        sali = transect['salinity']
        along_vals = transect['along'].values

        # Scatter column by column in reverse order so that the first
        # along-track columns are drawn last (on top). Only the scatter
        # points are rasterized.
        for a_idx in range(len(along_vals) - 1, -1, -1):
            T_col = temp.isel(along=a_idx).values
            S_col = sali.isel(along=a_idx).values
            mask = ~np.isnan(T_col) & ~np.isnan(S_col)
            if np.any(mask):
                ax.scatter(S_col[mask], T_col[mask], s=1.25,
                           color=cmap(norm(along_vals[a_idx])), marker='o',
                           rasterized=True)

        ax.set_title(times[idx].strftime('%d %b %Y'))

        # y labels only on the left-most column
        if i % ncols == 0:
            ax.set_ylabel("$\\theta$ (°C)")
        else:
            ax.set_yticklabels([])

        # x labels on every panel that has no panel beneath it. This covers
        # the last row and, when that row is only partly filled, the panels
        # sitting above the removed axes.
        if i + ncols >= n_panels:
            ax.set_xlabel("Salinity (psu)")
        else:
            ax.set_xticklabels([])

        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

    # Remove unused axes
    for j in range(n_panels, len(axes)):
        fig.delaxes(axes[j])

    # Shared colorbar in its own axes to the right of the grid
    fig.subplots_adjust(right=0.88, top=0.92)
    cbar_ax = fig.add_axes([0.9, 0.1, 0.02, 0.8])
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, cax=cbar_ax)
    cbar.set_label('Along-transect distance (m)', fontsize=12)

    fig.suptitle("T–S Diagrams Colored by Along Distance", fontsize=24)

    for ax in fig.axes:
        ax.set_rasterization_zorder(1)
    # To save at 600 DPI:
    # output_path = os.path.expanduser("~/Desktop/ts_grid_raster600dpi.pdf")
    # fig.savefig(output_path, format='pdf', dpi=600, bbox_inches='tight')

# T–S panel grid for every 2024–2025 transect over the first 30 000 m of the line.
plot_ts_shelf_basin_by_transect(
    cube, region_range=(0, 30000),
    target_years=[2024, 2025], target_months=list(range(1, 13)),
    xlim=(30.5, 34), ylim=(5, 16),
    ncols=4,
)

In [None]:
def plot_ts_shelf_basin_by_transect(
    cube,
    region_range=(0, 30000),
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    xlim=(30.5, 34),
    ylim=(5, 16),
    ncols=4,
    inset_months=[4, 5, 6, 7, 8, 9],
    inset_limits_by_date=None,
):
    """
    Grid of T–S diagrams (one panel per matching transect, points coloured
    by along-track distance) with an optional zoomed inset per panel.

    Parameters
    ----------
    cube : xarray.Dataset
        Needs an 'along' coordinate, a 'transect' coordinate and
        'temperature' / 'salinity' variables. It is loaded into memory.
        Each transect label is assumed to start with an 8-character date
        that pandas can parse (e.g. YYYYMMDD) — TODO confirm.
    region_range : (start, end)
        Along-track window passed to ``cube.sel(along=slice(start, end))``;
        it also sets the colour scale limits.
    target_years, target_months : sequence of int
        Calendar years / months a transect must fall in (read only).
    xlim, ylim : (min, max)
        Axis limits applied to every main panel.
    ncols : int
        Number of panel columns; rows are added as needed.
    inset_months : sequence of int
        Panels whose transect month is listed here get an inset.
    inset_limits_by_date : dict or None
        Optional per-date inset limits, keyed by 'YYYY-MM-DD', each value a
        dict with optional 'xlim' / 'ylim' entries. Missing entries fall
        back to (33.3, 33.8) and (6.0, 7.1).

    Notes
    -----
    Updates the global ``plt.rcParams`` font sizes as a side effect.
    Prints a message and returns without drawing when no transect matches.
    This definition reuses the name of a function defined in an earlier
    cell (without insets) and replaces it once run.
    """
    from scipy.optimize import root_scalar  # not imported at file level
    from matplotlib.colors import Normalize
    from matplotlib.patches import Rectangle

    plt.rcParams.update({
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "axes.titlesize": 14,
        "axes.labelsize": 12,
        "legend.fontsize": 10,
        "figure.titlesize": 16})

    cube = cube.load()
    subset = cube.sel(along=slice(*region_range))
    times = pd.to_datetime([str(t)[:8] for t in subset.transect.values])
    selected_idxs = [i for i, t in enumerate(times) if t.year in target_years and t.month in target_months]
    n_panels = len(selected_idxs)
    if n_panels == 0:
        # plt.subplots rejects nrows=0, so bail out with a clear message.
        print("No transects match the requested years/months.")
        return

    nrows = int(np.ceil(n_panels / ncols))
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                             figsize=(4.5 * ncols, 4 * nrows), squeeze=False)

    fig.subplots_adjust(
        wspace=0.1,
        hspace=0.1)
    axes = axes.ravel()

    # Isopycnal grid
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 300), np.linspace(5.3, 20, 300))
    sigma = gsw.sigma0(S_grid, T_grid)
    # Each label is pinned where its contour crosses T = T_label.
    T_label = 6.0
    isopycnal_levels = [23.0, 23.5, 24.0, 24.5, 25.0, 25.5, 25.6, 25.7, 25.8,
                        25.9, 26.0, 26.1, 26.2, 26.3, 26.4, 26.5]
    manual_labels = {}
    for iso in isopycnal_levels:
        sol = root_scalar(lambda S, iso=iso: gsw.sigma0(S, T_label) - iso,
                          bracket=[28, 35], method='brentq')
        manual_labels[iso] = (sol.root if sol.converged else np.nan, T_label)

    # (levels, colour, line width) shared by the main panels and the insets.
    isopycnal_styles = [
        (np.linspace(23, 27, 9), 'black', 0.5),
        ([25.6], 'white', 0.5), ([25.7], 'lime', 0.5),
        ([25.8], 'red', 0.5), ([25.9], 'blue', 0.5),
        ([26.0], 'black', 0.5), ([26.1], 'purple', 0.5),
        ([26.2], 'salmon', 0.5), ([26.3], 'yellow', 0.5),
        ([26.4], 'cyan', 0.5),
    ]

    cmap = plt.colormaps['jet']
    norm = Normalize(vmin=region_range[0], vmax=region_range[1])

    def _draw_isopycnals(target_ax, label_size):
        # Contour every isopycnal group on target_ax and label the known levels.
        for levels, color, lw in isopycnal_styles:
            cs = target_ax.contour(S_grid, T_grid, sigma, levels=levels,
                                   colors=color, linewidths=lw, linestyles='--')
            for level in levels:
                if level in manual_labels:
                    target_ax.clabel(cs, fmt='%1.2f', fontsize=label_size, inline=False,
                                     manual=[manual_labels[level]])

    def _scatter_columns(target_ax, temp, sali, along_vals, size):
        # Scatter each along-track column in its distance colour, last
        # column first so the earliest columns end up on top.
        for a_idx in range(len(along_vals) - 1, -1, -1):
            T_col = temp.isel(along=a_idx).values
            S_col = sali.isel(along=a_idx).values
            mask = ~np.isnan(T_col) & ~np.isnan(S_col)
            if np.any(mask):
                target_ax.scatter(S_col[mask], T_col[mask], s=size,
                                  color=cmap(norm(along_vals[a_idx])), marker='o')

    for i, idx in enumerate(selected_idxs):
        ax = axes[i]
        ax.set_facecolor('grey')

        # Contours and points go on before the axis limits are set, as before.
        _draw_isopycnals(ax, 7)

        transect = subset.isel(transect=idx)
        temp = transect['temperature']
        sali = transect['salinity']
        along_vals = transect['along'].values

        _scatter_columns(ax, temp, sali, along_vals, 1.25)

        ax.set_title(times[idx].strftime('%d %b %Y'))
        # y labels only on the left-most column
        if i % ncols == 0:
            ax.set_ylabel("$\\theta$ (\u00b0C)")
        else:
            ax.set_yticklabels([])
        # x labels on every panel that has no panel beneath it. This covers
        # the last row and, when that row is only partly filled, the panels
        # sitting above the removed axes.
        if i + ncols >= n_panels:
            ax.set_xlabel("Salinity (psu)")
        else:
            ax.set_xticklabels([])
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

        # --- Inset ---
        if times[idx].month in inset_months:
            inset = ax.inset_axes([0.5, 0.5, 0.48, 0.48])
            inset.set_facecolor('grey')
            inset_xlim_local = (33.3, 33.8)
            inset_ylim_local = (6.0, 7.1)

            # Per-date overrides take precedence over the default window.
            date_str = times[idx].strftime('%Y-%m-%d')
            if inset_limits_by_date and date_str in inset_limits_by_date:
                overrides = inset_limits_by_date[date_str]
                inset_xlim_local = overrides.get('xlim', inset_xlim_local)
                inset_ylim_local = overrides.get('ylim', inset_ylim_local)

            _draw_isopycnals(inset, 6)
            _scatter_columns(inset, temp, sali, along_vals, 0.5)

            inset.set_xlim(*inset_xlim_local)
            inset.set_ylim(*inset_ylim_local)
            inset.set_xticks([])
            inset.set_yticks([])

            # Dashed box on the main panel marking the region the inset shows.
            box = Rectangle(
                (inset_xlim_local[0], inset_ylim_local[0]),
                inset_xlim_local[1] - inset_xlim_local[0],
                inset_ylim_local[1] - inset_ylim_local[0],
                edgecolor='black', facecolor='none', linestyle='--', linewidth=1)
            ax.add_patch(box)

    for j in range(n_panels, len(axes)):
        fig.delaxes(axes[j])

    # Shared colorbar in its own axes to the right of the grid
    fig.subplots_adjust(right=0.88, top=0.92)
    cbar_ax = fig.add_axes([0.9, 0.1, 0.02, 0.8])
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, cax=cbar_ax)
    cbar.set_label('Along-transect distance (m)', fontsize=12)
    fig.suptitle("T–S Diagrams Colored by Along Distance", fontsize=24)

    for ax in fig.axes:
        ax.set_rasterization_zorder(1)

# Same grid with an inset on every panel (all twelve months); the dates
# listed below get their own inset window instead of the default one.
plot_ts_shelf_basin_by_transect(
    cube,
    region_range=(0, 30000),
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    xlim=(30.5, 34),
    ylim=(5, 16),
    ncols=4,
    inset_months=list(range(1, 13)),
    inset_limits_by_date={
        '2024-04-16': dict(xlim=(32.75, 33.5), ylim=(7.3, 8.5)),
        '2024-05-10': dict(xlim=(32.7, 33.5), ylim=(7.2, 8.1)),
        '2024-05-16': dict(xlim=(32.9, 33.6), ylim=(6.8, 7.9)),
        '2024-07-17': dict(xlim=(33.3, 33.8), ylim=(6, 7.1)),
        '2024-07-23': dict(xlim=(33.3, 33.8), ylim=(6, 7.1)),
    },
)

In [None]:
def plot_ts_histogram_by_transect(
    cube,
    region_range=(0, 30000),
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    sal_bins=np.linspace(30.5, 34, 200),
    temp_bins=np.linspace(5, 16, 200),
    ncols=4,
):
    """
    Grid of 2-D T–S histograms, one panel per matching transect.

    Parameters
    ----------
    cube : xarray.Dataset
        Needs an 'along' coordinate, a 'transect' coordinate and
        'temperature' / 'salinity' variables. It is loaded into memory.
        Each transect label is assumed to start with an 8-character date
        that pandas can parse (e.g. YYYYMMDD) — TODO confirm.
    region_range : (start, end)
        Along-track window passed to ``cube.sel(along=slice(start, end))``.
    target_years, target_months : sequence of int
        Calendar years / months a transect must fall in (read only).
    sal_bins, temp_bins : array-like
        Bin edges for salinity (x) and temperature (y); they also set the
        axis limits of every panel.
    ncols : int
        Number of panel columns; rows are added as needed.

    Notes
    -----
    Updates the global ``plt.rcParams`` font sizes as a side effect.
    Counts are colour-mapped on a log scale from 1 to 1000. Prints a
    message and returns without drawing when no transect matches.
    """
    from matplotlib.colors import LogNorm  # not imported at file level
    from matplotlib.cm import ScalarMappable

    plt.rcParams.update({
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "axes.titlesize": 14,
        "axes.labelsize": 12,
        "legend.fontsize": 10,
        "figure.titlesize": 16})

    cube = cube.load()
    subset = cube.sel(along=slice(*region_range))
    times = pd.to_datetime([str(t)[:8] for t in subset.transect.values])
    selected_idxs = [i for i, t in enumerate(times)
                     if t.year in target_years and t.month in target_months]

    n_panels = len(selected_idxs)
    if n_panels == 0:
        # plt.subplots rejects nrows=0, so bail out with a clear message.
        print("No transects match the requested years/months.")
        return

    nrows = int(np.ceil(n_panels / ncols))
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                             figsize=(4.5 * ncols, 4 * nrows), squeeze=False)
    axes = axes.ravel()

    # A single norm shared by all panels and the colorbar, so the colour
    # scale shown always matches the one used to draw.
    count_norm = LogNorm(vmin=1, vmax=1000)

    for i, idx in enumerate(selected_idxs):
        ax = axes[i]
        ax.set_facecolor('grey')

        transect = subset.isel(transect=idx)
        temp = transect['temperature'].values
        sal = transect['salinity'].values

        mask = ~np.isnan(temp) & ~np.isnan(sal)
        if not np.any(mask):
            # Nothing to bin for this transect: blank the panel.
            ax.axis('off')
            continue

        hist, xedges, yedges = np.histogram2d(sal[mask], temp[mask],
                                              bins=[sal_bins, temp_bins])

        # histogram2d returns (x, y)-indexed counts; pcolormesh wants (y, x).
        ax.pcolormesh(xedges, yedges, hist.T, cmap='inferno',
                      norm=count_norm, rasterized=True)

        ax.set_title(times[idx].strftime('%d %b %Y'), fontsize=10)
        ax.set_xlim(sal_bins[0], sal_bins[-1])
        ax.set_ylim(temp_bins[0], temp_bins[-1])

        # y labels only on the left-most column
        if i % ncols == 0:
            ax.set_ylabel("Temperature (°C)")
        else:
            ax.set_yticklabels([])

        # x labels on every panel that has no panel beneath it. This covers
        # the last row and, when that row is only partly filled, the panels
        # sitting above the removed axes.
        if i + ncols >= n_panels:
            ax.set_xlabel("Salinity (psu)")
        else:
            ax.set_xticklabels([])

        ax.grid(True)

    for j in range(n_panels, len(axes)):
        fig.delaxes(axes[j])

    # Shared colorbar in its own axes to the right of the grid
    fig.subplots_adjust(right=0.88, top=0.92)
    cbar_ax = fig.add_axes([0.9, 0.1, 0.02, 0.8])
    sm = ScalarMappable(cmap='inferno', norm=count_norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, cax=cbar_ax)
    cbar.set_label('Log-scaled Count', fontsize=12)

    fig.suptitle("T–S Histograms by Transect", fontsize=24)

    for ax in fig.axes:
        ax.set_rasterization_zorder(1)
    # output_path = os.path.expanduser("~/Desktop/ts_histograms.pdf")
    # fig.savefig(output_path, format='pdf', dpi=600, bbox_inches='tight')
    plt.show()

# Per-transect T–S histograms for 2024–2025, using 400 bin edges per axis.
plot_ts_histogram_by_transect(
    cube,
    region_range=(0, 30000),
    target_years=[2024, 2025], target_months=list(range(1, 13)),
    sal_bins=np.linspace(30.5, 34, 400), temp_bins=np.linspace(5, 16, 400),
    ncols=4,
)

In [None]:
def plot_ts_histogram_all_transects(
    cube,
    region_range=(0, 30000),
    target_years=[2024, 2025],
    sal_bins=np.linspace(30.5, 34, 200),
    temp_bins=np.linspace(5, 16, 200),
):
    """
    Single 2-D T–S histogram pooling every transect in the selected years.

    Parameters
    ----------
    cube : xarray.Dataset
        Needs an 'along' coordinate, a 'transect' coordinate and
        'temperature' / 'salinity' variables. It is loaded into memory.
        Each transect label is assumed to start with an 8-character date
        that pandas can parse (e.g. YYYYMMDD) — TODO confirm.
    region_range : (start, end)
        Along-track window passed to ``cube.sel(along=slice(start, end))``.
        The title reports it divided by 1000 as km, i.e. it assumes
        'along' is in metres.
    target_years : sequence of int
        Calendar years a transect must fall in (read only).
    sal_bins, temp_bins : array-like
        Bin edges for salinity (x) and temperature (y).

    Notes
    -----
    Updates the global ``plt.rcParams`` font sizes as a side effect.
    Counts are colour-mapped on a log scale from 1 to 10000. Prints the
    number of pooled points and the maximum bin count. Prints a message
    and returns without drawing when there is no valid data.
    """
    from matplotlib.colors import LogNorm  # not imported at file level

    plt.rcParams.update({
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "axes.titlesize": 16,
        "axes.labelsize": 14,
        "legend.fontsize": 12,
        "figure.titlesize": 18})

    cube = cube.load()
    subset = cube.sel(along=slice(*region_range))
    times = pd.to_datetime([str(t)[:8] for t in subset.transect.values])
    selected_idxs = [i for i, t in enumerate(times)
                     if t.year in target_years]

    # Pool the finite (T, S) pairs of every selected transect.
    all_T = []
    all_S = []
    for idx in selected_idxs:
        transect = subset.isel(transect=idx)
        temp = transect['temperature'].values
        sal = transect['salinity'].values
        mask = ~np.isnan(temp) & ~np.isnan(sal)
        all_T.append(temp[mask])
        all_S.append(sal[mask])

    if not all_T or not all_S:
        print("No valid temperature/salinity data.")
        return

    T = np.concatenate(all_T)
    S = np.concatenate(all_S)
    if T.size == 0:
        # Transects matched, but every sample was NaN.
        print("No valid temperature/salinity data.")
        return
    print(f"Total valid T–S points: {len(T)}")

    hist, xedges, yedges = np.histogram2d(S, T, bins=[sal_bins, temp_bins])

    fig, ax = plt.subplots(figsize=(8, 6))
    # histogram2d returns (x, y)-indexed counts; pcolormesh wants (y, x).
    pcm = ax.pcolormesh(xedges, yedges, hist.T, cmap='inferno',
                        norm=LogNorm(vmin=1, vmax=10000), rasterized=True)

    # ax.set_xlim(sal_bins[0], sal_bins[-1])
    # ax.set_ylim(temp_bins[0], temp_bins[-1])
    ax.set_xlabel("Salinity (psu)")
    ax.set_ylabel("Temperature (°C)")

    # Describe what was actually pooled, not a fixed period/region.
    years_used = sorted({times[i].year for i in selected_idxs})
    if years_used[0] == years_used[-1]:
        year_txt = f"{years_used[0]}"
    else:
        year_txt = f"{years_used[0]}–{years_used[-1]}"
    region_txt = f"{region_range[0] / 1000:g}–{region_range[1] / 1000:g} km"
    ax.set_title(f"T–S Histogram ({year_txt}, {region_txt})")
    ax.set_facecolor('grey')

    cbar = fig.colorbar(pcm, ax=ax)
    cbar.set_label("Log-scaled Count")

    fig.tight_layout()
    # output_path = os.path.expanduser("~/Desktop/ts_hist_all.pdf")
    # fig.savefig(output_path, format='pdf', dpi=600, bbox_inches='tight')

    plt.show()
    print(f"Max bin count: {hist.max()}")

# Pooled T–S histogram for 2020–2025 with 0.01-wide salinity and temperature bins.
plot_ts_histogram_all_transects(
    cube,
    region_range=(0, 30000),
    target_years=list(range(2020, 2026)),
    sal_bins=np.arange(28, 34.01, 0.01),
    temp_bins=np.arange(5, 16.01, 0.01),
)

In [None]:
def plot_ts_and_o2_density_histograms_side_by_side(
    cube,
    region_range=(0, 30000),
    target_years=[2020, 2021, 2022, 2023, 2024, 2025],
    sal_bins=np.arange(28, 34.01, 0.01),
    temp_bins=np.arange(5, 16.01, 0.01),
    dens_bins=np.arange(23, 27.01, 0.01),
    o2_bins=np.arange(0, 400.01, 1),
):
    """
    Plot log-scaled 2-D histograms of T–S and O2–density side by side.

    Every transect of ``cube`` whose date falls in ``target_years`` is pooled over
    the ``along`` window ``region_range``. The left panel bins salinity against
    temperature and overlays dashed isopycnals; the right panel bins oxygen
    against ``potential_density - 1000`` with the density axis inverted.

    Parameters
    ----------
    cube : dataset providing ``load()``, ``sel(along=...)``, ``isel(transect=...)``,
        a ``transect`` coordinate whose string form starts with a YYYYMMDD date,
        and the variables ``temperature``, ``salinity``, ``potential_density`` and
        ``oxygen_concentration`` (presumably an ``xarray.Dataset`` — confirm).
    region_range : (start, end) slice bounds along ``along``; assumed to be metres.
    target_years : years to include (never mutated here).
    sal_bins, temp_bins, dens_bins, o2_bins : histogram bin edges; their first and
        last values are also used as axis limits.

    Returns
    -------
    None. Prints the number of valid points and the maximum bin count per panel;
    prints "Missing data." and returns early when nothing can be plotted.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    from matplotlib.colors import LogNorm
    from matplotlib.cm import ScalarMappable
    import gsw
    from scipy.optimize import root_scalar

    plt.rcParams.update({
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "axes.titlesize": 16,
        "axes.labelsize": 14,
        "legend.fontsize": 12,
        "figure.titlesize": 18})

    cube = cube.load()
    subset = cube.sel(along=slice(*region_range))
    # Transect identifiers start with the date as YYYYMMDD.
    times = pd.to_datetime([str(t)[:8] for t in subset.transect.values])
    selected_idxs = [i for i, t in enumerate(times) if t.year in target_years]

    if not selected_idxs:
        print("Missing data.")
        return

    all_T, all_S, all_sigma, all_o2 = [], [], [], []

    for idx in selected_idxs:
        ds = subset.isel(transect=idx)
        temp = ds['temperature'].values
        sal = ds['salinity'].values
        sigma = ds['potential_density'].values - 1000
        o2 = ds['oxygen_concentration'].values

        # Each histogram only needs its own pair of variables to be valid.
        ts_mask = ~np.isnan(temp) & ~np.isnan(sal)
        o2_mask = ~np.isnan(o2) & ~np.isnan(sigma)

        all_T.append(temp[ts_mask])
        all_S.append(sal[ts_mask])
        all_sigma.append(sigma[o2_mask])
        all_o2.append(o2[o2_mask])

    T = np.concatenate(all_T)
    S = np.concatenate(all_S)
    SIGMA = np.concatenate(all_sigma)
    O2 = np.concatenate(all_o2)

    # Transects were selected but every sample is NaN: nothing to draw.
    if T.size == 0 and O2.size == 0:
        print("Missing data.")
        return

    print(f"Total valid T–S points: {len(T)}")
    print(f"Total valid O₂–σ₀ points: {len(O2)}")

    hist_ts, xedges_ts, yedges_ts = np.histogram2d(S, T, bins=[sal_bins, temp_bins])
    hist_o2, xedges_o2, yedges_o2 = np.histogram2d(O2, SIGMA, bins=[o2_bins, dens_bins])

    # One norm shared by both panels so the single colorbar describes both.
    log_norm = LogNorm(vmin=1, vmax=7500)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6),
                                   gridspec_kw={'width_ratios': [1, 1], 'wspace': 0.15})
    ax1.set_facecolor('grey')
    ax2.set_facecolor('grey')

    ax1.pcolormesh(xedges_ts, yedges_ts, hist_ts.T,
                   cmap='inferno', norm=log_norm, rasterized=True)
    ax2.pcolormesh(xedges_o2, yedges_o2, hist_o2.T,
                   cmap='inferno', norm=log_norm, rasterized=True)

    # ─── Isopycnals on T–S plot ─────────────────────
    # The grid spans the plotted bin range so the contours reach every edge of
    # the panel, including the temperature at which the labels are anchored.
    S_grid, T_grid = np.meshgrid(np.linspace(sal_bins[0], sal_bins[-1], 300),
                                 np.linspace(temp_bins[0], temp_bins[-1], 300))
    sigma_grid = gsw.sigma0(S_grid, T_grid)
    isopycnal_levels = np.arange(20, 27, 0.5)

    cs1 = ax1.contour(S_grid, T_grid, sigma_grid, levels=isopycnal_levels,
                      colors='black', linewidths=0.5, linestyles='--', alpha=0.7)

    # ─── Label isopycnals at the bottom ─────────────────────
    # For each level, solve for the salinity at which it crosses T_label.
    T_label = temp_bins[0] + 0.1
    manual_labels = []
    for iso in isopycnal_levels:
        try:
            sol = root_scalar(lambda S: gsw.sigma0(S, T_label) - iso,
                              bracket=[sal_bins[0], sal_bins[-1]], method='brentq')
        except ValueError:
            # No sign change in the bracket: this level does not cross T_label
            # inside the plotted salinity range, so it gets no label.
            continue
        if sol.converged:
            manual_labels.append((sol.root, T_label))

    ax1.clabel(cs1, fmt='%1.1f', fontsize=8, inline=False, manual=manual_labels)
    for o2_val in np.arange(50, o2_bins[-1], 50):
        ax2.axvline(o2_val, color='black', linestyle='--', linewidth=0.5, alpha=0.7)

    # ─── Horizontal lines on O2–density plot ─────────────────────
    # NOTE(review): the labels are white and sit just right of the axes, so they
    # are only visible on a dark figure background — confirm this is intended.
    for iso in isopycnal_levels:
        ax2.axhline(iso, color='black', linestyle='--', linewidth=0.5, alpha=0.7)
        ax2.text(x=o2_bins[-1] + 10, y=iso, s=f'{iso:.1f}', color='white',
                 va='center', fontsize=8)

    ax1.set_xlabel(r"Salinity (psu)")
    ax1.set_ylabel(r"$\theta$ (°C)")
    ax2.set_xlabel(r"$O_2$ ($\mu$mol/L)")
    ax2.set_ylabel(r"$\sigma_\theta$ (kg m$^{-3}$)")

    ax1.set_xlim(sal_bins[0], sal_bins[-1])
    ax1.set_ylim(temp_bins[0], temp_bins[-1])
    ax2.set_xlim(o2_bins[0], o2_bins[-1])
    ax2.set_ylim(dens_bins[-1], dens_bins[0])  # denser water at the bottom

    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    sm = ScalarMappable(norm=log_norm, cmap='inferno')
    sm.set_array([])
    fig.colorbar(sm, cax=cbar_ax).set_label("Log-scaled Count")

    # Build the title from the arguments instead of a fixed region/year string.
    # Assumes `along` is in metres, matching the km values used in this notebook.
    years = sorted(set(target_years))
    if len(years) > 1 and years == list(range(years[0], years[-1] + 1)):
        years_label = f"{years[0]}–{years[-1]}"
    else:
        years_label = ", ".join(str(y) for y in years)
    region_label = f"{region_range[0] / 1000:g}–{region_range[1] / 1000:g} km"
    fig.suptitle(rf"$\theta$–$S$ and $O_2$–$\sigma_\theta$ ({region_label}, {years_label})",
                 fontsize=20)

    print(f"Max T–S bin count: {hist_ts.max()}")
    print(f"Max O₂–σ₀ bin count: {hist_o2.max()}")

# Two-panel T–S / O2–density figure for along = 0–30000, transects dated 2024–2025.
plot_ts_and_o2_density_histograms_side_by_side(
    cube,
    region_range=(0, 30000), target_years=[2024, 2025],
    sal_bins=np.arange(30, 34.01, 0.01), temp_bins=np.arange(5, 16.01, 0.01),
    dens_bins=np.arange(23, 27.01, 0.01), o2_bins=np.arange(0, 425.01, 1),
)

In [None]:
# Inspect the loaded dataset: as the last expression of the cell, its repr is shown as the cell output.
cube

In [None]:
def plot_ts_o2sigma_to2_histograms_side_by_side(
    cube,
    region_range=(0, 30000),
    target_years=[2020, 2021, 2022, 2023, 2024, 2025],
    sal_bins=np.arange(28, 34.01, 0.01),
    temp_bins=np.arange(5, 16.01, 0.01),
    dens_bins=np.arange(23, 27.01, 0.01),
    o2_bins=np.arange(0, 450.01, 1),
):
    """
    Plot three log-scaled 2-D histograms in a row: T–S, O2–density and T–O2.

    Every transect of ``cube`` whose date falls in ``target_years`` is pooled over
    the ``along`` window ``region_range``. All panels share one ``LogNorm`` that
    runs from 1 to the largest bin count of the three histograms, so the single
    colorbar applies to each of them.

    Parameters
    ----------
    cube : dataset providing ``load()``, ``sel(along=...)``, ``isel(transect=...)``,
        a ``transect`` coordinate whose string form starts with a YYYYMMDD date,
        and the variables ``temperature``, ``salinity``, ``potential_density`` and
        ``oxygen_concentration`` (presumably an ``xarray.Dataset`` — confirm).
    region_range : (start, end) slice bounds along ``along``; assumed to be metres.
    target_years : years to include (never mutated here).
    sal_bins, temp_bins, dens_bins, o2_bins : histogram bin edges; their first and
        last values are also used as axis limits.

    Returns
    -------
    None. Prints point counts and maximum bin counts; prints "Missing data." and
    returns early when no transect matches or no sample falls inside the bins.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    from matplotlib.colors import LogNorm
    from matplotlib.cm import ScalarMappable
    import gsw
    from scipy.optimize import root_scalar

    plt.rcParams.update({
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "axes.titlesize": 16,
        "axes.labelsize": 14,
        "legend.fontsize": 12,
        "figure.titlesize": 18})

    cube = cube.load()
    subset = cube.sel(along=slice(*region_range))
    # Transect identifiers start with the date as YYYYMMDD.
    times = pd.to_datetime([str(t)[:8] for t in subset.transect.values])
    selected_idxs = [i for i, t in enumerate(times) if t.year in target_years]

    if not selected_idxs:
        print("Missing data.")
        return

    all_T, all_S, all_sigma, all_o2 = [], [], [], []
    all_T_to2, all_O2_to2 = [], []

    for idx in selected_idxs:
        ds = subset.isel(transect=idx)

        temp = ds['temperature'].values
        sal = ds['salinity'].values
        sigma = ds['potential_density'].values - 1000
        o2 = ds['oxygen_concentration'].values

        # Each histogram only needs its own pair of variables to be valid.
        ts_mask = ~np.isnan(temp) & ~np.isnan(sal)
        o2_mask = ~np.isnan(o2) & ~np.isnan(sigma)
        to2_mask = ~np.isnan(temp) & ~np.isnan(o2)

        all_T.append(temp[ts_mask])
        all_S.append(sal[ts_mask])
        all_sigma.append(sigma[o2_mask])
        all_o2.append(o2[o2_mask])
        all_T_to2.append(temp[to2_mask])
        all_O2_to2.append(o2[to2_mask])

    T = np.concatenate(all_T)
    S = np.concatenate(all_S)
    SIGMA = np.concatenate(all_sigma)
    O2 = np.concatenate(all_o2)
    T_to2 = np.concatenate(all_T_to2)
    O2_to2 = np.concatenate(all_O2_to2)

    print(f"Total valid T–S points: {len(T)}")
    print(f"Total valid O₂–σ₀ points: {len(O2)}")
    print(f"Total valid T–O₂ points: {len(T_to2)}")

    hist_ts, xedges_ts, yedges_ts = np.histogram2d(S, T, bins=[sal_bins, temp_bins])
    hist_o2, xedges_o2, yedges_o2 = np.histogram2d(O2, SIGMA, bins=[o2_bins, dens_bins])
    hist_to2, xedges_to2, yedges_to2 = np.histogram2d(O2_to2, T_to2, bins=[o2_bins, temp_bins])

    combined_max = max(hist_ts.max(), hist_o2.max(), hist_to2.max())
    if combined_max < 1:
        # Every histogram is empty; LogNorm(vmin=1, vmax=0) would be invalid.
        print("Missing data.")
        return
    log_norm = LogNorm(vmin=1, vmax=combined_max)

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 6),
                                        gridspec_kw={'width_ratios': [1, 1, 1], 'wspace': 0.2})
    for panel in (ax1, ax2, ax3):
        panel.set_facecolor('grey')

    # ─── T–S Histogram ─────────────────────
    ax1.pcolormesh(xedges_ts, yedges_ts, hist_ts.T,
                   cmap='inferno', norm=log_norm, rasterized=True)
    ax1.set_xlabel("Salinity (psu)")
    ax1.set_ylabel("Temperature (°C)")
    ax1.set_xlim(sal_bins[0], sal_bins[-1])
    ax1.set_ylim(temp_bins[0], temp_bins[-1])

    # ─── O₂–σ₀ Histogram ─────────────────────
    ax2.pcolormesh(xedges_o2, yedges_o2, hist_o2.T,
                   cmap='inferno', norm=log_norm, rasterized=True)
    ax2.set_xlabel("Oxygen (µmol/L)")
    ax2.set_ylabel("σ₀ (kg/m³)")
    ax2.set_xlim(o2_bins[0], o2_bins[-1])
    ax2.set_ylim(dens_bins[-1], dens_bins[0])  # Flip y-axis

    # ─── T–O₂ Histogram ─────────────────────
    ax3.pcolormesh(xedges_to2, yedges_to2, hist_to2.T,
                   cmap='inferno', norm=log_norm, rasterized=True)
    ax3.set_xlabel("Oxygen (µmol/L)")
    ax3.set_ylabel("Temperature (°C)")
    ax3.set_xlim(o2_bins[0], o2_bins[-1])
    ax3.set_ylim(temp_bins[0], temp_bins[-1])

    # ─── Isopycnals on T–S and σ₀ lines ─────────────────────
    # The grid spans the plotted bin range so the contours reach every edge of
    # the panel, including the temperature at which the labels are anchored.
    S_grid, T_grid = np.meshgrid(np.linspace(sal_bins[0], sal_bins[-1], 300),
                                 np.linspace(temp_bins[0], temp_bins[-1], 300))
    sigma_grid = gsw.sigma0(S_grid, T_grid)
    isopycnal_levels = np.arange(20, 27, 0.5)

    cs1 = ax1.contour(S_grid, T_grid, sigma_grid, levels=isopycnal_levels,
                      colors='black', linewidths=0.5, linestyles='--')

    # Place isopycnal labels just above the lower temperature bound: for each
    # level, solve for the salinity at which it crosses T_label.
    T_label = temp_bins[0] + 0.1
    manual_labels = []
    for iso in isopycnal_levels:
        try:
            sol = root_scalar(lambda S: gsw.sigma0(S, T_label) - iso,
                              bracket=[sal_bins[0], sal_bins[-1]], method='brentq')
        except ValueError:
            # No sign change in the bracket: this level does not cross T_label
            # inside the plotted salinity range, so it gets no label.
            continue
        if sol.converged:
            manual_labels.append((sol.root, T_label))

    ax1.clabel(cs1, fmt='%1.1f', fontsize=8, inline=False, manual=manual_labels)

    # NOTE(review): the labels are white and sit just right of the axes, so they
    # are only visible on a dark background — confirm this is intended.
    for iso in isopycnal_levels:
        ax2.axhline(iso, color='black', linestyle='--', linewidth=0.5)
        ax2.text(x=o2_bins[-1] + 10, y=iso, s=f'{iso:.1f}', color='white',
                 va='center', fontsize=8)

    # ─── Shared Colorbar ─────────────────────
    cbar_ax = fig.add_axes([0.93, 0.15, 0.02, 0.7])
    sm = ScalarMappable(norm=log_norm, cmap='inferno')
    sm.set_array([])
    cbar = fig.colorbar(sm, cax=cbar_ax)
    cbar.set_label("Log-scaled Count")

    # Build the title from the arguments instead of a fixed region/year string.
    # Assumes `along` is in metres, matching the km values used in this notebook.
    years = sorted(set(target_years))
    if len(years) > 1 and years == list(range(years[0], years[-1] + 1)):
        years_label = f"{years[0]}–{years[-1]}"
    else:
        years_label = ", ".join(str(y) for y in years)
    region_label = f"{region_range[0] / 1000:g}–{region_range[1] / 1000:g} km"
    fig.suptitle(
        rf"$\theta$–$S$, $O_2$–$\sigma_\theta$, and $\theta$–$O_2$ Histograms ({region_label}, {years_label})",
        fontsize=20)
    plt.show()

    print(f"Max T–S bin count: {hist_ts.max()}")
    print(f"Max O₂–σ₀ bin count: {hist_o2.max()}")
    print(f"Max T–O₂ bin count: {hist_to2.max()}")

# Three-panel histogram figure for along = 0–30000, transects dated 2020–2025.
plot_ts_o2sigma_to2_histograms_side_by_side(
    cube,
    region_range=(0, 30000),
    target_years=list(range(2020, 2026)),
    sal_bins=np.arange(30, 34.01, 0.01), temp_bins=np.arange(5, 16.01, 0.01),
    dens_bins=np.arange(22, 27.01, 0.01), o2_bins=np.arange(0, 425.01, 1),
)

In [None]:
def plot_ts_profile_averages(cube, basin_range=(5000, 15000), sill_range=(23500, 30000), shelf_range=(65000, 70000),
                             target_years=[2024], target_months=[5], xlim=None, ylim=None):
    """
    T–S diagram of mean temperature and salinity for the basin, sill and shelf regions.

    For each region, the transects dated within ``target_years`` and
    ``target_months`` are averaged over the ``transect`` and ``along`` dimensions.
    The values that remain (one per level of the leftover dimension — presumably
    depth, confirm) are drawn as scatter points over dashed isopycnals.

    Parameters
    ----------
    cube : dataset providing ``load()``, ``sel(along=...)``, ``isel(transect=...)``,
        a ``transect`` coordinate whose string form starts with a YYYYMMDD date,
        and ``temperature`` / ``salinity`` variables (presumably an
        ``xarray.Dataset`` — confirm).
    basin_range, sill_range, shelf_range : (start, end) slice bounds along
        ``along``; assumed to be metres and reported in km in the legend.
    target_years, target_months : years and months of the transects to average.
    xlim, ylim : optional axis limits.

    Returns
    -------
    None.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import gsw
    from scipy.optimize import root_scalar

    plt.rcParams.update({
        "xtick.labelsize": 30,
        "ytick.labelsize": 30,
        "axes.titlesize": 30,
        "axes.labelsize": 30,
        "legend.fontsize": 20,
        "figure.titlesize": 40})

    cube = cube.load()
    fig, ax = plt.subplots(figsize=(1.5 * 1.5 * 6.4, 1.5 * 4.8))
    ax.set_facecolor('grey')

    # ─── Isopycnal Grid ───────────────────────────────────────
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 300), np.linspace(5.3, 20, 300))
    sigma = gsw.sigma0(S_grid, T_grid)

    # Manual isopycnal labels at fixed θ: solve for the salinity at which each
    # level crosses T_label.
    T_label = 6.0
    isopycnal_levels = [23.0, 23.5, 24.0, 24.5, 25.0, 25.5, 25.6, 25.7, 25.8, 25.9, 26.0, 26.1, 26.2, 26.3, 26.4, 26.5]
    manual_labels = {}
    for iso in isopycnal_levels:
        sol = root_scalar(lambda S: gsw.sigma0(S, T_label) - iso, bracket=[28, 35], method='brentq')
        manual_labels[iso] = (sol.root if sol.converged else np.nan, T_label)

    # Contour and label isopycnals: a black half-unit backbone plus individually
    # coloured tenth-unit levels between 25.6 and 26.4.
    for levels, color, lw in [
        (np.linspace(23, 27, 9), 'black', 0.5),
        ([25.6], 'white', 0.5),
        ([25.7], 'lime', 0.5),
        ([25.8], 'red', 0.5),
        ([25.9], 'blue', 0.5),
        ([26.0], 'black', 0.5),
        ([26.1], 'purple', 0.5),
        ([26.2], 'salmon', 0.5),
        ([26.3], 'yellow', 0.5),
        ([26.4], 'cyan', 0.5)
    ]:
        cs = ax.contour(S_grid, T_grid, sigma, levels=levels,
                        colors=color, linewidths=lw, linestyles='--')
        for level in levels:
            if level in manual_labels:
                ax.clabel(cs, fmt='%1.2f', fontsize=10, inline=False, inline_spacing=0.1,
                          manual=[manual_labels[level]])

    # ─── Mean Profiles by Region ────────────────────────────────
    # Slicing along `along` leaves the transect coordinate untouched, so the
    # matching transects are identified once for all regions.
    times = pd.to_datetime([str(t)[:8] for t in cube.transect.values])
    selected_idxs = [i for i, t in enumerate(times)
                     if t.year in target_years and t.month in target_months]

    def _km(metres):
        # Compact km text, e.g. 23500 -> "23.5", 5000 -> "5".
        return f"{metres / 1000:g}"

    regions = [
        (basin_range, 'Basin water', 'cyan'),
        (sill_range, 'Sill water', 'red'),
        (shelf_range, 'QCS water', 'salmon'),
    ]

    plotted_any = False
    for (start, end), name, color in (regions if selected_idxs else []):
        small = cube.sel(along=slice(start, end)).isel(transect=selected_idxs)[['temperature', 'salinity']]

        T = small['temperature'].mean(dim=('transect', 'along')).values
        S = small['salinity'].mean(dim=('transect', 'along')).values
        mask = ~np.isnan(T) & ~np.isnan(S)
        # Legend text is derived from the range actually used, so it cannot drift
        # away from the arguments.
        ax.scatter(S[mask], T[mask], label=f"{name} ({_km(start)}–{_km(end)} km)", color=color, s=20)
        plotted_any = True

    ax.set_xlabel("Salinity (psu)")
    ax.set_ylabel(r"$\theta$ (°C)")
    ax.set_title("Mean T–S Profiles by Region")
    if plotted_any:
        ax.legend()  # skipped when nothing was plotted to avoid an empty-legend warning
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)
    plt.tight_layout()

# Example call: mean T–S by region for transects dated May 2024.
plot_ts_profile_averages(
    cube,
    basin_range=(5000, 15000), sill_range=(23500, 30000), shelf_range=(65000, 75000),
    target_years=[2024], target_months=[5],
    xlim=(30.5, 34), ylim=(5, 12),
)

In [None]:
def plot_ts_profile_grid(cube, basin_range=(5000, 15000), sill_range=(23500, 30000), shelf_range=(65000, 70000),
                         target_years=[2024], target_months=list(range(1, 13)), xlim=None, ylim=None, ncols=3):
    """
    Grid of T–S diagrams with monthly mean profiles for the basin, sill and shelf regions.

    One subplot is drawn per (year, month) that has at least one transect. In each
    panel, every region is averaged over the ``transect`` and ``along`` dimensions
    of that month's transects and drawn as a line over dashed isopycnals.

    NOTE: a later cell of this notebook defines another function with the same
    name (a per-transect variant); once that cell has run, this definition is
    shadowed.

    Parameters
    ----------
    cube : dataset providing ``sel(along=...)``, ``isel(transect=...)``, a
        ``transect`` coordinate whose string form starts with a YYYYMMDD date, and
        ``temperature`` / ``salinity`` variables (presumably an ``xarray.Dataset``
        — confirm).
    basin_range, sill_range, shelf_range : (start, end) slice bounds along
        ``along``; assumed to be metres and reported in km in the legend.
    target_years, target_months : years and months to include.
    xlim, ylim : optional axis limits applied to every panel.
    ncols : number of subplot columns.

    Returns
    -------
    None. Prints a message and returns early when no transect matches.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import gsw
    from scipy.optimize import root_scalar

    plt.rcParams.update({
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "axes.titlesize": 14,
        "axes.labelsize": 14,
        "legend.fontsize": 10,
        "figure.titlesize": 16})

    # Group transect indices by (year, month). Slicing along `along` does not
    # touch the transect coordinate, so this grouping is valid for every region.
    times = pd.to_datetime([str(t)[:8] for t in cube.transect.values])
    grouped = {}
    for i, t in enumerate(times):
        if t.year in target_years and t.month in target_months:
            grouped.setdefault((t.year, t.month), []).append(i)
    grouped_keys = sorted(grouped)

    n_panels = len(grouped_keys)
    if n_panels == 0:
        # plt.subplots(0, ncols) would raise; report and stop instead.
        print("No transects match the requested years/months.")
        return

    nrows = int(np.ceil(n_panels / ncols))
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so flatten() is safe.
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 5, nrows * 4),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.flatten()

    # ─── Isopycnal Grid and Labels ─────────────────────────────
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 300), np.linspace(5.3, 20, 300))
    sigma = gsw.sigma0(S_grid, T_grid)
    T_label = 6.0
    isopycnal_levels = [23.0, 23.5, 24.0, 24.5, 25.0, 25.5, 25.6, 25.7, 25.8, 25.9, 26.0, 26.1, 26.2, 26.3, 26.4, 26.5]
    manual_labels = {}
    for iso in isopycnal_levels:
        # Salinity at which this isopycnal crosses T_label.
        sol = root_scalar(lambda S: gsw.sigma0(S, T_label) - iso, bracket=[28, 35], method='brentq')
        manual_labels[iso] = (sol.root if sol.converged else np.nan, T_label)

    contour_specs = [
        (np.linspace(23, 27, 9), 'black', 0.5),
        ([25.6], 'white', 0.5),
        ([25.7], 'lime', 0.5),
        ([25.8], 'red', 0.5),
        ([25.9], 'blue', 0.5),
        ([26.0], 'black', 0.5),
        ([26.1], 'purple', 0.5),
        ([26.2], 'salmon', 0.5),
        ([26.3], 'yellow', 0.5),
        ([26.4], 'cyan', 0.5)
    ]

    def _km(metres):
        # Compact km text, e.g. 23500 -> "23.5", 5000 -> "5".
        return f"{metres / 1000:g}"

    # Region subsets and legend text are loop-invariant; the legend text is built
    # from the ranges actually used.
    regions = [
        (cube.sel(along=slice(*bounds)), f"{name} ({_km(bounds[0])}–{_km(bounds[1])} km)", color)
        for bounds, name, color in [
            (basin_range, 'Basin water', 'cyan'),
            (sill_range, 'Sill water', 'red'),
            (shelf_range, 'QCS water', 'salmon'),
        ]
    ]

    for j, (year, month) in enumerate(grouped_keys):
        ax = axes[j]
        ax.set_facecolor('grey')

        # Plot isopycnals
        for levels, color, lw in contour_specs:
            cs = ax.contour(S_grid, T_grid, sigma, levels=levels,
                            colors=color, linewidths=lw, linestyles='--')
            for level in levels:
                if level in manual_labels:
                    ax.clabel(cs, fmt='%1.2f', fontsize=8, inline=False, inline_spacing=0.1,
                              manual=[manual_labels[level]])

        has_data = False
        for region, label, color in regions:
            small = region.isel(transect=grouped[(year, month)])[['temperature', 'salinity']]
            T = small['temperature'].mean(dim=('transect', 'along')).values
            S = small['salinity'].mean(dim=('transect', 'along')).values
            mask = ~np.isnan(T) & ~np.isnan(S)
            if not np.any(mask):
                continue  # nothing valid for this region in this month
            has_data = True
            ax.plot(S[mask], T[mask], label=label, color=color)

        month_title = pd.to_datetime(f'{year}-{month:02}').strftime('%b %Y')
        if has_data:
            ax.set_title(f"{month_title}")
            ax.legend(fontsize=8)
        else:
            ax.set_title(f"{month_title} (no data)")

        if j % ncols == 0:
            ax.set_ylabel(r"$\theta$ (°C)")
        if j + ncols >= n_panels:
            # Bottom-most panel of its column (the last row may be partial):
            # sharex hides tick labels outside the last row, so re-enable them.
            ax.set_xlabel("Salinity (psu)")
            ax.tick_params(labelbottom=True)

    # Remove the unused axes of a partially filled last row.
    for j in range(n_panels, len(axes)):
        fig.delaxes(axes[j])

    if xlim is not None:
        for ax in axes[:n_panels]:
            ax.set_xlim(xlim)
    if ylim is not None:
        for ax in axes[:n_panels]:
            ax.set_ylim(ylim)

    fig.suptitle("Monthly Mean T–S Profiles by Region", fontsize=18)
    plt.tight_layout(rect=[0, 0, 1, 0.95])

# Monthly-mean T–S grid, three panels per row, for every month of 2024 and 2025.
plot_ts_profile_grid(
    cube,
    basin_range=(5000, 15000), sill_range=(23500, 30000), shelf_range=(65000, 75000),
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),  # all twelve months
    xlim=(29, 34), ylim=(5, 16),
    ncols=3,
)

In [None]:
def plot_ts_profile_grid(cube, basin_range=(5000, 15000), sill_range=(23500, 30000), shelf_range=(65000, 70000),
                         target_years=[2024], target_months=list(range(1, 13)), xlim=None, ylim=None, ncols=3):
    """
    Grid of T–S diagrams with per-transect profiles for the basin, sill and shelf regions.

    One subplot is drawn per (year, month) that has at least one transect. Within
    a panel, each transect of that month is averaged over ``along`` for each
    region and scattered over dashed isopycnals: the month's first transect uses
    '.' markers, later ones '^'.

    NOTE: this reuses the name of the monthly-mean variant defined in an earlier
    cell of this notebook and replaces it once this cell has run.

    Parameters
    ----------
    cube : dataset providing ``sel(along=...)``, a ``transect`` coordinate whose
        string form starts with a YYYYMMDD date, and ``temperature`` /
        ``salinity`` variables (presumably an ``xarray.Dataset`` — confirm).
    basin_range, sill_range, shelf_range : (start, end) slice bounds along
        ``along``; assumed to be metres and reported in km in the legend.
    target_years, target_months : years and months to include.
    xlim, ylim : optional axis limits applied to every panel.
    ncols : number of subplot columns.

    Returns
    -------
    None. Prints a message and returns early when no transect matches.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import gsw
    from scipy.optimize import root_scalar

    plt.rcParams.update({
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "axes.titlesize": 14,
        "axes.labelsize": 14,
        "legend.fontsize": 10,
        "figure.titlesize": 16})

    # Group transect indices by (year, month).
    times = pd.to_datetime([str(t)[:8] for t in cube.transect.values])
    grouped = {}
    for i, t in enumerate(times):
        if t.year in target_years and t.month in target_months:
            grouped.setdefault((t.year, t.month), []).append(i)
    grouped_keys = sorted(grouped)

    n_panels = len(grouped_keys)
    if n_panels == 0:
        # plt.subplots(0, ncols) would raise; report and stop instead.
        print("No transects match the requested years/months.")
        return

    nrows = int(np.ceil(n_panels / ncols))
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so flatten() is safe.
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 5, nrows * 4),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.flatten()

    # ─── Isopycnal Grid and Labels ─────────────────────────────
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 300), np.linspace(5.3, 20, 300))
    sigma = gsw.sigma0(S_grid, T_grid)
    T_label = 6.0
    isopycnal_levels = [23.0, 23.5, 24.0, 24.5, 25.0, 25.5, 25.6, 25.7, 25.8, 25.9, 26.0, 26.1, 26.2, 26.3, 26.4, 26.5]
    manual_labels = {}
    for iso in isopycnal_levels:
        # Salinity at which this isopycnal crosses T_label.
        sol = root_scalar(lambda S: gsw.sigma0(S, T_label) - iso, bracket=[28, 35], method='brentq')
        manual_labels[iso] = (sol.root if sol.converged else np.nan, T_label)

    contour_specs = [
        (np.linspace(23, 27, 9), 'black', 0.5),
        ([25.6], 'white', 0.5),
        ([25.7], 'lime', 0.5),
        ([25.8], 'red', 0.5),
        ([25.9], 'blue', 0.5),
        ([26.0], 'black', 0.5),
        ([26.1], 'purple', 0.5),
        ([26.2], 'salmon', 0.5),
        ([26.3], 'yellow', 0.5),
        ([26.4], 'cyan', 0.5)
    ]

    def _km(metres):
        # Compact km text, e.g. 23500 -> "23.5", 5000 -> "5".
        return f"{metres / 1000:g}"

    # Region subsets and legend text are loop-invariant; the legend text is built
    # from the ranges actually used.
    regions = [
        (cube.sel(along=slice(*bounds)), f"{name} ({_km(bounds[0])}–{_km(bounds[1])} km)", color)
        for bounds, name, color in [
            (basin_range, 'Basin water', 'cyan'),
            (sill_range, 'Sill water', 'red'),
            (shelf_range, 'QCS water', 'salmon'),
        ]
    ]

    for j, (year, month) in enumerate(grouped_keys):
        ax = axes[j]
        ax.set_facecolor('grey')

        # Plot isopycnals
        for levels, color, lw in contour_specs:
            cs = ax.contour(S_grid, T_grid, sigma, levels=levels,
                            colors=color, linewidths=lw, linestyles='--')
            for level in levels:
                if level in manual_labels:
                    ax.clabel(cs, fmt='%1.2f', fontsize=8, inline=False, inline_spacing=0.1,
                              manual=[manual_labels[level]])

        has_data = False
        month_idxs = grouped[(year, month)]
        for region, label, color in regions:
            labelled = False
            for i in month_idxs:
                T = region['temperature'].isel(transect=i).mean(dim='along').values
                S = region['salinity'].isel(transect=i).mean(dim='along').values
                mask = ~np.isnan(T) & ~np.isnan(S)
                if not np.any(mask):
                    continue
                marker_style = '.' if i == month_idxs[0] else '^'
                # Label the first transect that is actually drawn, so the region
                # still gets a legend entry when the month's first transect is empty.
                ax.scatter(S[mask], T[mask], label=None if labelled else label,
                           color=color, s=5, marker=marker_style)
                labelled = True
                has_data = True

        month_title = pd.to_datetime(f'{year}-{month:02}').strftime('%b %Y')
        if has_data:
            ax.set_title(f"{month_title}")
            ax.legend(fontsize=8)
        else:
            ax.set_title(f"{month_title} (no data)")

        if j % ncols == 0:
            ax.set_ylabel(r"$\theta$ (°C)")
        if j + ncols >= n_panels:
            # Bottom-most panel of its column (the last row may be partial):
            # sharex hides tick labels outside the last row, so re-enable them.
            ax.set_xlabel("Salinity (psu)")
            ax.tick_params(labelbottom=True)

    # Remove the unused axes of a partially filled last row.
    for j in range(n_panels, len(axes)):
        fig.delaxes(axes[j])

    if xlim is not None:
        for ax in axes[:n_panels]:
            ax.set_xlim(xlim)
    if ylim is not None:
        for ax in axes[:n_panels]:
            ax.set_ylim(ylim)

    fig.suptitle("Monthly T–S Profiles by Region (per transect)", fontsize=18)
    plt.tight_layout(rect=[0, 0, 1, 0.95])

# Per-transect T–S grid, three panels per row, for every month of 2024 and 2025.
plot_ts_profile_grid(
    cube,
    basin_range=(5000, 15000), sill_range=(23500, 30000), shelf_range=(65000, 75000),
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),  # all twelve months
    xlim=(30.5, 34), ylim=(5, 16),
    ncols=3,
)

In [None]:
def plot_ts_profile_by_transect(
    cube, 
    basin_range=(5000, 15000), 
    sill_range=(23500, 30000),
    target_years=[2024], 
    target_months=list(range(1, 13)), 
    xlim=None, 
    ylim=None, 
    ncols=3,
    inset_months=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],  # Toggle insets by month
    inset_xlim=(33, 33.5),  # Zoom box salinity range
    inset_ylim=(6.5, 7.5),   # Zoom box temperature range
    inset_limits_by_date=None
):
    """
    Grid of T–S diagrams (one per transect) for the basin and sill regions.

    For each transect dated within ``target_years`` / ``target_months``, the
    temperature and salinity of each region are averaged over ``along`` and
    scattered over dashed isopycnals. The point nearest 130 m depth is marked and
    annotated. For months listed in ``inset_months`` a zoomed inset is added and
    its extent is outlined with a dashed box on the main panel.

    Parameters
    ----------
    cube : dataset providing ``sel(along=...)``, a ``transect`` coordinate whose
        string form starts with a YYYYMMDD date, ``temperature`` / ``salinity``
        variables and a ``depth`` coordinate (presumably an ``xarray.Dataset`` —
        confirm).
    basin_range, sill_range : (start, end) slice bounds along ``along``; assumed
        to be metres and reported in km in the legend.
    target_years, target_months : years and months of the transects to draw.
    xlim, ylim : optional axis limits applied to every main panel.
    ncols : number of subplot columns.
    inset_months : months for which the zoom inset is drawn.
    inset_xlim, inset_ylim : default salinity / temperature extent of the inset.
    inset_limits_by_date : optional mapping ``'YYYY-MM-DD' -> {'xlim': ..., 'ylim': ...}``
        overriding the inset extent for individual transects.

    Returns
    -------
    None. Prints a message and returns early when no transect matches.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import gsw
    from matplotlib.patches import Rectangle
    from scipy.optimize import root_scalar

    plt.rcParams.update({
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "axes.titlesize": 14,
        "axes.labelsize": 14,
        "legend.fontsize": 10,
        "figure.titlesize": 16})

    # ─── Transect Times ─────────────────────────────
    times = pd.to_datetime([str(t)[:8] for t in cube.transect.values])
    valid_indices = [i for i, t in enumerate(times) if t.year in target_years and t.month in target_months]
    if not valid_indices:
        # plt.subplots(0, ncols) would raise; report and stop instead.
        print("No transects match the requested years/months.")
        return
    times = times[valid_indices]

    n_panels = len(valid_indices)
    nrows = int(np.ceil(n_panels / ncols))
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so flatten() is safe.
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 5, nrows * 4),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.flatten()

    # ─── Isopycnal Grid and Labels ─────────────────────────────
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 300), np.linspace(5.3, 20, 300))
    sigma = gsw.sigma0(S_grid, T_grid)
    T_label = 6.0
    isopycnal_levels = [23.0, 23.5, 24.0, 24.5, 25.0, 25.5, 25.6, 25.7, 25.8, 25.9, 26.0, 26.1, 26.2, 26.3, 26.4, 26.5]
    manual_labels = {}
    for iso in isopycnal_levels:
        # Salinity at which this isopycnal crosses T_label.
        sol = root_scalar(lambda S: gsw.sigma0(S, T_label) - iso, bracket=[28, 35], method='brentq')
        manual_labels[iso] = (sol.root if sol.converged else np.nan, T_label)

    contour_specs = [
        (np.linspace(23, 27, 9), 'black', 0.5),
        ([25.6], 'white', 0.5),
        ([25.7], 'lime', 0.5),
        ([25.8], 'red', 0.5),
        ([25.9], 'blue', 0.5),
        ([26.0], 'black', 0.5),
        ([26.1], 'purple', 0.5),
        ([26.2], 'salmon', 0.5),
        ([26.3], 'yellow', 0.5),
        ([26.4], 'cyan', 0.5)
    ]

    target_depth = 130  # m; level highlighted on every profile

    def _draw_isopycnals(target_ax, label_fontsize):
        # Draw the dashed isopycnals and their labels on a main panel or an inset.
        for levels, color, lw in contour_specs:
            cs = target_ax.contour(S_grid, T_grid, sigma, levels=levels,
                                   colors=color, linewidths=lw, linestyles='--')
            for level in levels:
                if level in manual_labels:
                    target_ax.clabel(cs, fmt='%1.2f', fontsize=label_fontsize, inline=False,
                                     inline_spacing=0.1, manual=[manual_labels[level]])

    def _point_at_target_depth(avg_temp, avg_sali):
        # (salinity, temperature) nearest target_depth, or None when unavailable or NaN.
        try:
            t_val = avg_temp.sel(depth=target_depth, method='nearest').values
            s_val = avg_sali.sel(depth=target_depth, method='nearest').values
        except KeyError:
            return None  # depth not found; the marker is best-effort
        if np.isnan(t_val) or np.isnan(s_val):
            return None
        return s_val, t_val

    def _km(metres):
        # Compact km text, e.g. 23500 -> "23.5", 5000 -> "5".
        return f"{metres / 1000:g}"

    # Region subsets and legend text are loop-invariant; the legend text is built
    # from the ranges actually used.
    regions = [
        (cube.sel(along=slice(*bounds)), f"{name} ({_km(bounds[0])}–{_km(bounds[1])} km)", color)
        for bounds, name, color in [
            (basin_range, 'Basin', 'cyan'),
            (sill_range, 'Sill', 'red'),
        ]
    ]

    for j, i in enumerate(valid_indices):
        ax = axes[j]
        ax.set_facecolor('grey')
        _draw_isopycnals(ax, 8)

        # Compute each region's profile once; the inset reuses it.
        profiles = []
        for region, label, color in regions:
            avg_temp = region['temperature'].isel(transect=i).mean(dim='along')
            avg_sali = region['salinity'].isel(transect=i).mean(dim='along')
            T = avg_temp.values
            S = avg_sali.values
            mask = ~np.isnan(T) & ~np.isnan(S)
            if not np.any(mask):
                continue
            ax.scatter(S[mask], T[mask], label=label, color=color, s=5)
            point = _point_at_target_depth(avg_temp, avg_sali)
            if point is not None:
                ax.scatter(point[0], point[1], color='black', s=6)
                ax.text(point[0], point[1], f'{int(target_depth)} m',
                        color=color, ha='left', va='bottom')
            profiles.append((S[mask], T[mask], point, color))

        date_title = times[j].strftime('%d %b %Y')
        if not profiles:
            ax.set_title(f"{date_title} (no data)")
            continue

        ax.set_title(date_title)
        if j % ncols == 0:
            ax.set_ylabel(r"$\theta$ (°C)")
        if j + ncols >= n_panels:
            # Bottom-most panel of its column (the last row may be partial):
            # sharex hides tick labels outside the last row, so re-enable them.
            ax.set_xlabel("Salinity (psu)")
            ax.tick_params(labelbottom=True)

        # ─── Add Inset Only for Selected Months ─────────────────────────────
        if times[j].month not in inset_months:
            continue

        # Default extent, optionally overridden for this transect's date.
        inset_xlim_local, inset_ylim_local = inset_xlim, inset_ylim
        if inset_limits_by_date:
            override = inset_limits_by_date.get(times[j].strftime('%Y-%m-%d'), {})
            inset_xlim_local = override.get('xlim', inset_xlim_local)
            inset_ylim_local = override.get('ylim', inset_ylim_local)

        inset = ax.inset_axes([0.5, 0.5, 0.48, 0.48])
        inset.set_facecolor('grey')
        _draw_isopycnals(inset, 6)

        for S_valid, T_valid, point, color in profiles:
            inset.scatter(S_valid, T_valid, color=color, s=3)
            # Annotate the target depth only when it falls inside the zoom box.
            if (
                point is not None and
                inset_xlim_local[0] <= point[0] <= inset_xlim_local[1] and
                inset_ylim_local[0] <= point[1] <= inset_ylim_local[1]
            ):
                inset.scatter(point[0], point[1], color='black', s=4)
                inset.text(point[0], point[1], f'{int(target_depth)} m',
                           color=color, ha='left', va='bottom')

        inset.set_xlim(*inset_xlim_local)
        inset.set_ylim(*inset_ylim_local)
        inset.tick_params(labelsize=6)
        inset.set_xticks([])
        inset.set_yticks([])

        # ─── Draw dashed black box in main axes ─────────────────────────────
        box = Rectangle(
            (inset_xlim_local[0], inset_ylim_local[0]),
            inset_xlim_local[1] - inset_xlim_local[0],
            inset_ylim_local[1] - inset_ylim_local[0],
            edgecolor='black',
            facecolor='none',
            linestyle='--',
            linewidth=1
        )
        ax.add_patch(box)

    # ─── Master Legend Items ─────────────────────────────
    legend_handles = [plt.Line2D([0], [0], marker='o', color='w', label=label,
                                 markerfacecolor=color, markersize=15)
                      for _, label, color in regions]

    # Remove the unused axes of a partially filled last row.
    for j in range(n_panels, len(axes)):
        fig.delaxes(axes[j])

    if xlim is not None:
        for ax in axes[:n_panels]:
            ax.set_xlim(xlim)
    if ylim is not None:
        for ax in axes[:n_panels]:
            ax.set_ylim(ylim)

    fig.legend(handles=legend_handles, loc='upper right', ncol=1, fontsize=20)
    fig.suptitle("Depth-Averaged T–S Profiles", fontsize=32)
    plt.tight_layout(rect=[0, 0, 1, 0.97])  # Leave room at top for legend

# One panel per transect dated 2024–2025 (four per row); insets for April–September,
# with hand-tuned zoom boxes for five individual transect dates.
plot_ts_profile_by_transect(
    cube,
    basin_range=(0, 15000), sill_range=(20000, 30000),
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    xlim=(30.5, 34), ylim=(5, 16),
    ncols=4,
    inset_months=list(range(4, 10)),
    inset_limits_by_date={
        '2024-04-16': dict(xlim=(32.4, 33.5), ylim=(7.3, 8.5)),
        '2024-05-10': dict(xlim=(32.7, 33.5), ylim=(7.2, 8.1)),
        '2024-05-16': dict(xlim=(32.9, 33.6), ylim=(6.8, 7.9)),
        '2024-07-17': dict(xlim=(33.3, 33.8), ylim=(6, 7.1)),
        '2024-07-23': dict(xlim=(33.3, 33.8), ylim=(6, 7.1)),
    },
)

In [None]:
def plot_o2_profile_by_transect(
    cube,
    basin_range=(5000, 15000),
    sill_range=(23500, 30000),
    shelf_range=(65000, 70000),
    target_years=[2024],
    target_months=list(range(1, 13)),
    xlim=None,
    ylim=None,
    ncols=3):
    """
    Grid of oxygen–temperature diagrams, one panel per transect.

    For each transect whose date falls in ``target_years`` / ``target_months``,
    temperature and oxygen are averaged over the ``along`` dimension inside
    each of the three regions and scattered as oxygen (x) against
    temperature (y). Transects with no valid O2–temperature pair in any
    region are skipped.

    Parameters
    ----------
    cube : dataset exposing ``temperature`` and ``oxygen_concentration`` with
        ``transect`` and ``along`` dimensions (presumably an xarray Dataset).
    basin_range, sill_range, shelf_range : (start, end)
        Bounds passed to ``cube.sel(along=slice(start, end))``. Legend labels
        are generated from these values divided by 1000 and shown as "km",
        so ``along`` is assumed to be in metres — TODO confirm.
    target_years, target_months : sequences of int
        Transect dates are parsed from the first 8 characters of each
        ``transect`` label.
    xlim, ylim : tuple or None
        Axis limits applied to every panel; ``None`` keeps autoscaling.
    ncols : int
        Number of panel columns.

    Returns
    -------
    None. Prints a message and returns early when nothing can be plotted.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd

    plt.rcParams.update({
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "axes.titlesize": 14,
        "axes.labelsize": 14,
        "legend.fontsize": 10,
        "figure.titlesize": 16})

    # (name, bounds, colour). Labels are derived from the bounds so the legend
    # always matches the ranges that were actually averaged.
    regions = [
        ('Basin', basin_range, 'cyan'),
        ('Sill', sill_range, 'red'),
        ('QCS', shelf_range, 'salmon'),
    ]

    def _range_label(name, bounds):
        """Format e.g. ('Sill', (23500, 30000)) as 'Sill (23.5–30 km)'."""
        start, end = bounds
        return f"{name} ({start / 1000:g}–{end / 1000:g} km)"

    def _region_points(i, bounds):
        """Return the valid (O2, T) pairs of the along-averaged region profile."""
        region = cube.sel(along=slice(*bounds))
        T = region['temperature'].isel(transect=i).mean(dim='along').values
        O2 = region['oxygen_concentration'].isel(transect=i).mean(dim='along').values
        mask = ~np.isnan(T) & ~np.isnan(O2)
        return O2[mask], T[mask]

    # ─── Collect plottable transects (data extracted once) ─────────────────
    times = pd.to_datetime([str(t)[:8] for t in cube.transect.values])
    panels = []
    for i, t in enumerate(times):
        if t.year not in target_years or t.month not in target_months:
            continue
        series = []
        for name, bounds, color in regions:
            O2, T = _region_points(i, bounds)
            if O2.size:
                series.append((_range_label(name, bounds), color, O2, T))
        if series:
            panels.append((t, series))

    n_panels = len(panels)
    if n_panels == 0:
        print("No transects with valid O2–temperature data.")
        return

    nrows = int(np.ceil(n_panels / ncols))
    # squeeze=False always returns a 2-D array, so a 1×1 grid does not hand
    # back a bare Axes object (which has no .flatten()).
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 5, nrows * 4),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.ravel()

    for j, (t, series) in enumerate(panels):
        ax = axes[j]
        ax.set_facecolor('grey')

        for label, color, O2, T in series:
            ax.scatter(O2, T, label=label, color=color, s=5)

        ax.set_title(t.strftime('%d %b %Y'))
        ax.legend(fontsize=8)

        if j % ncols == 0:
            ax.set_ylabel(r"$\theta$ (°C)")
        if j // ncols == nrows - 1:
            ax.set_xlabel("Oxygen (μmol/kg)")

    # Drop the unused trailing axes of the last row.
    for j in range(n_panels, len(axes)):
        fig.delaxes(axes[j])

    if xlim is not None:
        for ax in axes[:n_panels]:
            ax.set_xlim(xlim)
    if ylim is not None:
        for ax in axes[:n_panels]:
            ax.set_ylim(ylim)

    fig.suptitle("Oxygen–Temperature Profiles by Transect", fontsize=18)
    plt.tight_layout(rect=[0, 0, 1, 0.95])


# O2–temperature panels for 2023, using the same regions and axis limits as
# the 2024–2025 figure below.
plot_o2_profile_by_transect(
    cube,
    target_years=[2023],
    target_months=list(range(1, 13)),
    basin_range=(0, 15000),
    sill_range=(23500, 30000),
    shelf_range=(65000, 75000),
    xlim=(0, 350),  # pass None to autoscale
    ylim=(4, 14),
    ncols=4,
)

# O2–temperature panels for 2024–2025.
plot_o2_profile_by_transect(
    cube,
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    basin_range=(0, 15000),
    sill_range=(23500, 30000),
    shelf_range=(65000, 75000),
    xlim=(0, 350),  # pass None to autoscale
    ylim=(4, 14),
    ncols=4,
)

In [None]:
def plot_ts(cube, along_values, target_years=[2023, 2024],
            shallowest_depth=None, xlim=None, ylim=None, target_months=None):
    """
    Scatter-style T–S diagram with sigma-theta contours on a single axes.

    Each selected transect gets its own colour (sampled from 'jet' in
    transect order) and each entry of ``along_values`` gets its own marker
    shape.

    Parameters
    ----------
    cube : dataset exposing ``temperature`` and ``salinity`` with an ``along``
        coordinate and a ``transect`` coordinate (presumably an xarray
        Dataset). Data are read positionally as ``[transect, :, along]``,
        which assumes that dimension order — TODO confirm.
    along_values : number or sequence of numbers
        Along-track positions to plot. Only grid points whose ``along``
        value exactly matches one of these survive the initial filter.
    target_years : sequence of int
    shallowest_depth : unused; kept so existing calls keep working.
    xlim, ylim : tuple or None
        Axis limits; ``None`` derives them from the plotted data ± 0.1.
    target_months : sequence of int or None
        ``None`` accepts every month.

    Returns
    -------
    None. Prints a message and returns early when nothing matches.
    """
    import itertools

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import gsw

    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 20,
        'figure.titlesize': 20})

    # Accept a bare scalar (including numpy scalars) as a one-element list.
    if np.ndim(along_values) == 0:
        along_values = [along_values]

    cube = cube.where(cube['along'].isin(along_values), drop=True)

    this_along_vals = cube['temperature'].coords['along'].values
    if this_along_vals.size == 0:
        # Without this guard np.argmin below fails on an empty array.
        print(f"No along-track grid points match along = {along_values}")
        return

    transects = cube.transect.values
    times = [pd.to_datetime(str(t)[:8]) for t in transects]
    selected_idxs = [i for i, t in enumerate(times)
                     if t.year in target_years and (target_months is None or t.month in target_months)]

    if not selected_idxs:
        print(f"No valid transects for years {target_years} at along = {along_values}")
        return

    fig, ax = plt.subplots(figsize=(1.5 * 1.5 * 6.4, 1.5 * 4.8))
    ax.set_facecolor('grey')

    # Colormap for time
    cmap = plt.colormaps.get_cmap('jet')
    color_map = cmap(np.linspace(0, 1, len(selected_idxs)))

    # One visible shape per along_value. Cycling means any number of locations
    # works; the previous mapping required exactly two values and gave the
    # second one the empty marker '', which matplotlib renders as nothing.
    marker_shapes = ('o', '^', 's', 'D', 'v', 'P', 'X', '*')
    marker_map = dict(zip(along_values, itertools.cycle(marker_shapes)))

    all_S, all_T = [], []

    # Sigma-theta grid
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 200),
                                 np.linspace(5.3, 20, 200))
    sigma = gsw.sigma0(S_grid, T_grid)

    for levels, color, lw in [
            (np.linspace(24, 27, 7), 'black', 0.5),
            ([25.6], 'white', 0.5),
            ([25.7], 'lime', 0.5),
            ([25.8], 'red', 0.5),
            ([25.9], 'blue', 0.5),
            ([26.0], 'black', 0.5),
            ([26.1], 'purple', 0.5),
            ([26.2], 'salmon', 0.5),
            ([26.3], 'yellow', 0.5),
            ([26.4], 'cyan', 0.5)]:
        cs = ax.contour(S_grid, T_grid, sigma, levels=levels,
                        colors=color, linewidths=lw, linestyles='--')
        # Labels are only drawn for contours given a non-default line width;
        # with the table above (all 0.5) no contour is labelled.
        if lw != 0.5:
            ax.clabel(cs, fmt='%1.2f', fontsize=9, inline=True)

    for i, idx in enumerate(selected_idxs):
        date = times[idx]

        for along_val in along_values:
            along_idx = np.argmin(np.abs(this_along_vals - along_val))

            # Use every level of the middle dimension.
            T = cube['temperature'][idx, :, along_idx].values.flatten()
            S = cube['salinity'][idx, :, along_idx].values.flatten()

            mask = ~np.isnan(T) & ~np.isnan(S)
            if not mask.any():
                continue

            ax.scatter(
                S[mask], T[mask],
                color=color_map[i],
                label=f"{date.strftime('%d %b %Y')} @ {int(along_val)} m",
                marker=marker_map.get(along_val, 'o'),
                linewidth=0.2)
            all_S.extend(S[mask])
            all_T.extend(T[mask])

    # Axis limits
    if all_S and all_T:
        all_S = np.array(all_S)
        all_T = np.array(all_T)
        if xlim is None:
            xlim = (np.nanmin(all_S) - 0.1, np.nanmax(all_S) + 0.1)
        if ylim is None:
            ylim = (np.nanmin(all_T) - 0.1, np.nanmax(all_T) + 0.1)

    ax.set_xlabel("Salinity (psu)")
    ax.set_ylabel(r"$\theta$ (°C)")
    ax.set_title(f"T–S Diagram at {', '.join([str(a)+' m' for a in along_values])}")
    # Only build a legend when something was plotted (avoids an empty-legend warning).
    if ax.get_legend_handles_labels()[0]:
        ax.legend(fontsize=9, loc='upper right')
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    plt.tight_layout()

# Single-axes T–S diagram for March–August 2024 at the 6500 m and 23500 m
# along-track positions.
plot_ts(
    cube,
    along_values=[6500, 23500],
    target_years=[2024],
    target_months=list(range(3, 9)),
    shallowest_depth=0,
)

In [None]:
def plot_ts_grid(cube, along_values, target_years=[2023, 2024],
                     shallowest_depth=None, xlim=None, ylim=None, target_months=None,
                     ncols=3):
    """
    T–S diagrams with one subplot per (year, month).

    All points use the same marker; the colour depends on the along-track
    location (red for 6500, cyan for 23500, gray for anything else). The
    x/y labels are written once for the whole figure.

    Parameters
    ----------
    cube : dataset exposing ``temperature`` and ``salinity`` with ``along``
        and ``transect`` coordinates (presumably an xarray Dataset). Data are
        read positionally as ``[transect, :, along]``, which assumes that
        dimension order — TODO confirm.
    along_values : number or sequence of numbers
        Along-track positions; only grid points exactly matching one of them
        survive the initial filter.
    target_years : sequence of int
    shallowest_depth : unused; kept so existing calls keep working.
    xlim, ylim : tuple or None
        Shared axis limits; ``None`` derives them from the plotted data ± 0.1.
    target_months : sequence of int or None
        ``None`` accepts every month.
    ncols : int
        Number of panel columns.

    Returns
    -------
    None. Prints a message and returns early when nothing matches.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import gsw

    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 16,
        'axes.labelsize': 14,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 10,
        'figure.titlesize': 20})

    # Accept a bare scalar (including numpy scalars) as a one-element list.
    if np.ndim(along_values) == 0:
        along_values = [along_values]

    # ─── Filter by station ─────────────────────────────────────
    cube = cube.where(cube['along'].isin(along_values), drop=True)

    this_along_vals = cube['temperature'].coords['along'].values
    if this_along_vals.size == 0:
        # Without this guard np.argmin below fails on an empty array.
        print(f"No along-track grid points match along = {along_values}")
        return

    # ─── Select time points ────────────────────────────────────
    transects = cube.transect.values
    times = [pd.to_datetime(str(t)[:8]) for t in transects]
    selected = [(i, t) for i, t in enumerate(times)
                if t.year in target_years and (target_months is None or t.month in target_months)]

    if not selected:
        print(f"No valid transects for years {target_years} at along = {along_values}")
        return

    # ─── Group by (year, month) ────────────────────────────────
    grouped = {}
    for i, t in selected:
        grouped.setdefault((t.year, t.month), []).append((i, t))

    sorted_keys = sorted(grouped)
    n_panels = len(sorted_keys)
    nrows = int(np.ceil(n_panels / ncols))
    # squeeze=False always returns a 2-D array, so a 1×1 grid does not hand
    # back a bare Axes object (which has no .flatten()).
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 5, nrows * 4),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.ravel()

    # ─── Custom colors ─────────────────────────────────────────
    station_colors = {
        6500: 'red',
        23500: 'cyan'
    }
    marker_style = '.'

    # ─── Sigma-theta grid ──────────────────────────────────────
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 200),
                                 np.linspace(5, 20, 200))
    sigma = gsw.sigma0(S_grid, T_grid)

    all_S, all_T = [], []

    for j, (year, month) in enumerate(sorted_keys):
        ax = axes[j]
        ax.set_facecolor('lightgrey')

        # ─── Density Contours ───────────────────────────────────
        for levels, contour_color, lw in [
                (np.linspace(24, 27, 7), 'black', 0.5),
                ([25.6], 'white', 0.5), ([25.7], 'lime', 0.5), ([25.8], 'red', 0.5),
                ([25.9], 'blue', 0.5), ([26.0], 'black', 0.5), ([26.1], 'purple', 0.5),
                ([26.2], 'salmon', 0.5), ([26.3], 'yellow', 0.5), ([26.4], 'cyan', 0.5)]:
            cs = ax.contour(S_grid, T_grid, sigma, levels=levels,
                            colors=contour_color, linewidths=lw, linestyles='--')
            # Labels are only drawn for contours given a non-default line
            # width; with the table above (all 0.5) none is labelled.
            if lw != 0.5:
                ax.clabel(cs, fmt='%1.2f', fontsize=8, inline=True)

        for i, t in grouped[(year, month)]:
            for along_val in along_values:
                along_idx = np.argmin(np.abs(this_along_vals - along_val))
                T = cube['temperature'][i, :, along_idx].values.flatten()
                S = cube['salinity'][i, :, along_idx].values.flatten()
                mask = ~np.isnan(T) & ~np.isnan(S)
                if not mask.any():
                    continue

                ax.scatter(S[mask], T[mask],
                           s=60,
                           color=station_colors.get(along_val, 'gray'),
                           marker=marker_style,
                           linewidth=0.2,
                           label=f"{t.strftime('%d %b')} @ {int(along_val)} m")
                all_S.extend(S[mask])
                all_T.extend(T[mask])

        ax.set_title(f"{pd.to_datetime(f'{year}-{month:02}', format='%Y-%m').strftime('%B %Y')}")
        # A month can be selected yet hold no valid T–S pairs; skip the legend
        # then instead of triggering matplotlib's empty-legend warning.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(fontsize=8, loc='upper right')

    # ─── Axis Limits ───────────────────────────────────────────
    # Fill in missing limits from the data when there is any ...
    if all_S and all_T:
        all_S = np.array(all_S)
        all_T = np.array(all_T)
        if xlim is None:
            xlim = (np.nanmin(all_S) - 0.1, np.nanmax(all_S) + 0.1)
        if ylim is None:
            ylim = (np.nanmin(all_T) - 0.1, np.nanmax(all_T) + 0.1)

    # ... and always honour limits the caller passed explicitly, even when no
    # valid points were found.
    for ax in axes:
        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None:
            ax.set_ylim(ylim)

    # ─── Shared Labels ─────────────────────────────────────────
    fig.text(0.5, 0.01, "Salinity (psu)", ha='center', fontsize=16)
    fig.text(0.01, 0.5, r"$\theta$ (°C)", va='center', rotation='vertical', fontsize=16)

    # ─── Remove unused axes & title ────────────────────────────
    for j in range(n_panels, len(axes)):
        fig.delaxes(axes[j])

    fig.suptitle("T–S Diagrams at Sill and FHS Deep Basin")
    plt.tight_layout(rect=[0.05, 0.05, 1, 0.95])

# Monthly T–S panels for 2024–2025 at the 6500 m and 23500 m positions, with
# fixed salinity/temperature limits.
plot_ts_grid(
    cube,
    along_values=[6500, 23500],
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    xlim=(30, 34),
    ylim=(5, 11),
)

In [None]:

def plot_to_grid(cube, along_values, target_years=[2023, 2024],
                 shallowest_depth=None, xlim=None, ylim=None,
                 target_months=None, ncols=3):
    """
    Oxygen–temperature diagrams with one subplot per (year, month).

    Only months that contain at least one valid O2–temperature pair get a
    panel. All points use the same marker; the colour depends on the
    along-track location (red for 6500, cyan for 23500, gray otherwise).
    The x/y labels are written once for the whole figure.

    Parameters
    ----------
    cube : dataset exposing ``oxygen_concentration`` and ``temperature`` with
        ``along`` and ``transect`` coordinates (presumably an xarray
        Dataset). Data are read positionally as ``[transect, :, along]``,
        which assumes that dimension order — TODO confirm.
    along_values : number or sequence of numbers
        Along-track positions; only grid points exactly matching one of them
        survive the initial filter.
    target_years : sequence of int
    shallowest_depth : unused; kept so existing calls keep working.
    xlim, ylim : tuple or None
        Shared axis limits; ``None`` keeps matplotlib's autoscaling.
    target_months : sequence of int or None
        ``None`` accepts every month.
    ncols : int
        Number of panel columns.

    Returns
    -------
    None. Prints a message and returns early when nothing can be plotted.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd

    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 16,
        'axes.labelsize': 14,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 10,
        'figure.titlesize': 20})

    # Accept a bare scalar (including numpy scalars) as a one-element list.
    if np.ndim(along_values) == 0:
        along_values = [along_values]

    cube = cube.where(cube['along'].isin(along_values), drop=True)

    this_along_vals = cube['along'].values
    if this_along_vals.size == 0:
        # Without this guard np.argmin below fails on an empty array.
        print(f"No along-track grid points match along = {along_values}")
        return

    transects = cube.transect.values
    times = [pd.to_datetime(str(t)[:8]) for t in transects]
    selected = [(i, t) for i, t in enumerate(times)
                if t.year in target_years and (target_months is None or t.month in target_months)]

    if not selected:
        print(f"No valid transects for years {target_years} at along = {along_values}")
        return

    # Extract the valid points once, grouped by (year, month). Months with no
    # valid pair never get a key, which is what keeps them out of the grid.
    panels = {}
    for i, t in selected:
        for along_val in along_values:
            along_idx = np.argmin(np.abs(this_along_vals - along_val))
            O2 = cube['oxygen_concentration'][i, :, along_idx].values.flatten()
            T = cube['temperature'][i, :, along_idx].values.flatten()
            mask = ~np.isnan(O2) & ~np.isnan(T)
            if mask.any():
                panels.setdefault((t.year, t.month), []).append(
                    (t, along_val, O2[mask], T[mask]))

    sorted_keys = sorted(panels)
    if not sorted_keys:
        print("No valid months with data.")
        return

    n_panels = len(sorted_keys)
    nrows = int(np.ceil(n_panels / ncols))
    # squeeze=False always returns a 2-D array, so a 1×1 grid does not hand
    # back a bare Axes object (which has no .flatten()).
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 5, nrows * 4),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.ravel()

    station_colors = {
        6500: 'red',
        23500: 'cyan'
    }
    marker_style = '.'

    for j, (year, month) in enumerate(sorted_keys):
        ax = axes[j]
        ax.set_facecolor('lightgrey')

        for t, along_val, O2, T in panels[(year, month)]:
            ax.scatter(O2, T,
                       s=60,
                       color=station_colors.get(along_val, 'gray'),
                       marker=marker_style,
                       linewidth=0.2,
                       label=f"{t.strftime('%d %b')} @ {int(along_val)} m")

        ax.set_title(f"{pd.to_datetime(f'{year}-{month:02}', format='%Y-%m').strftime('%B %Y')}")
        ax.legend(fontsize=8, loc='upper right')

    if xlim is not None:
        for ax in axes:
            ax.set_xlim(xlim)
    if ylim is not None:
        for ax in axes:
            ax.set_ylim(ylim)

    fig.text(0.5, 0.01, "Oxygen (μmol/kg)", ha='center', fontsize=16)
    fig.text(0.01, 0.5, r"$\theta$ (°C)", va='center', rotation='vertical', fontsize=16)

    # Drop the unused trailing axes of the last row.
    for j in range(n_panels, len(axes)):
        fig.delaxes(axes[j])

    fig.suptitle("Oxygen–Temperature Diagrams at Sill and FHS Deep Basin")
    plt.tight_layout(rect=[0.05, 0.05, 1, 0.95])

# Monthly O2–temperature panels for 2024–2025 at the 6500 m and 23500 m
# positions; the x-axis is left to autoscale.
plot_to_grid(
    cube,
    along_values=[6500, 23500],
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    # xlim=(30, 34),
    ylim=(5, 11),
)

In [None]:
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import numpy as np
import os
import cmocean as cm
import waypoint_distance as wd
import pandas as pd
from pathlib import Path
from datetime import datetime
from matplotlib.dates import DateFormatter
import gsw
import matplotlib.dates as mdates
%matplotlib widget

# Open the NetCDF inputs. Every path is expanded explicitly with
# os.path.expanduser so that none of them depends on the reader's own
# handling of "~".
cube = xr.open_dataset(os.path.expanduser('~/Desktop/Summer 2025 Python/calvert_cube.nc'))
topo = xr.open_dataset(os.path.expanduser('~/Desktop/Summer 2025 Python/british_columbia_3_msl_2013.nc'))
ds = xr.open_dataset(os.path.expanduser('~/Desktop/Summer 2025 Python/Hakai_calvert.nc'))

def plot_o2_density_by_transect(
    cube,
    region_range=(0, 30000),
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    xlim=(0, 350),
    ylim=(27, 23),  # High density at bottom
    ncols=4,
):
    """
    Scatter oxygen against sigma-0 for each transect, coloured by along-track position.

    The cube is loaded, cut to ``region_range`` along ``along``, and every
    transect dated within ``target_years`` / ``target_months`` that has at
    least one valid O2–density pair gets its own panel. Within a panel each
    along-track column is drawn separately, from the last column to the
    first, with a 'jet' colour scaled between the two ``region_range``
    bounds. A single colorbar is added on the right of the figure.

    Parameters
    ----------
    cube : dataset exposing ``oxygen_concentration`` and
        ``potential_density`` with ``transect`` and ``along`` dimensions
        (presumably an xarray Dataset). Per-transect arrays are indexed as
        ``[:, along]``, i.e. assumed 2-D — TODO confirm.
    region_range : (start, end) bounds for ``cube.sel(along=slice(...))``.
    target_years, target_months : sequences of int.
    xlim, ylim : axis limits applied to every panel.
    ncols : number of panel columns.

    Returns
    -------
    None. Prints a message and returns early when no transect has data.
    """
    import os
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    from matplotlib.colors import Normalize

    plt.rcParams.update({
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "axes.titlesize": 14,
        "axes.labelsize": 12,
        "legend.fontsize": 10,
        "figure.titlesize": 16})

    cube = cube.load()
    subset = cube.sel(along=slice(*region_range))
    times = pd.to_datetime([str(label)[:8] for label in subset.transect.values])

    def _section_arrays(transect_idx):
        """Return (oxygen, density minus 1000, along values) of one transect."""
        section = subset.isel(transect=transect_idx)
        oxygen = section['oxygen_concentration'].values
        dens = section['potential_density'].values - 1000
        return oxygen, dens, section['along'].values

    # ─── Keep only in-window transects that hold valid data ────────────────
    valid_idxs = []
    for idx, when in enumerate(times):
        if when.year not in target_years or when.month not in target_months:
            continue
        oxygen, dens, _ = _section_arrays(idx)
        if (~np.isnan(oxygen) & ~np.isnan(dens)).any():
            valid_idxs.append(idx)

    if len(valid_idxs) == 0:
        print("No valid transects with data.")
        return

    n_panels = len(valid_idxs)
    nrows = int(np.ceil(n_panels / ncols))
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                             figsize=(4.5 * ncols, 4 * nrows), squeeze=False)
    axes = axes.ravel()

    cmap = plt.colormaps['jet']
    norm = Normalize(vmin=region_range[0], vmax=region_range[1])

    for panel_no, idx in enumerate(valid_idxs):
        ax = axes[panel_no]
        ax.set_facecolor('grey')

        oxygen, dens, along_vals = _section_arrays(idx)

        # Walk the columns from last to first so lower-index columns are
        # drawn on top.
        for col in range(len(along_vals) - 1, -1, -1):
            o2_col = oxygen[:, col]
            dens_col = dens[:, col]
            keep = ~np.isnan(o2_col) & ~np.isnan(dens_col)
            if keep.any():
                ax.scatter(o2_col[keep], dens_col[keep], s=0.25,
                           color=cmap(norm(along_vals[col])), marker='o')

        ax.set_title(times[idx].strftime('%d %b %Y'))

        # Only the first column carries y labels ...
        in_first_col = (panel_no % ncols == 0)
        if in_first_col:
            ax.set_ylabel("σ₀ (kg/m³)")
        else:
            ax.set_yticklabels([])

        # ... and only the last row carries x labels.
        in_last_row = (panel_no // ncols == nrows - 1)
        if in_last_row:
            ax.set_xlabel("Oxygen (µmol/L)")
        else:
            ax.set_xticklabels([])

        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.grid(color='black', linestyle='--', linewidth=0.5, alpha=0.7)

    # ─── Remove unused axes ─────────────────────────────
    for spare in axes[n_panels:]:
        fig.delaxes(spare)

    # Reserve the right margin for a shared colorbar.
    fig.subplots_adjust(right=0.88, top=0.92)
    cbar_ax = fig.add_axes([0.9, 0.1, 0.02, 0.8])
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, cax=cbar_ax)
    cbar.set_label('Along-transect distance (m)', fontsize=12)

    fig.suptitle("O₂–σ₀", fontsize=24)

    for any_ax in fig.axes:
        any_ax.set_rasterization_zorder(1)

# Per-transect O2–σ₀ scatter panels for 2024–2025 over the first 30000 units
# of `along`, three panels per row.
plot_o2_density_by_transect(
    cube,
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    region_range=(0, 30000),
    xlim=(0, 350),
    ylim=(27, 23),
    ncols=3,
)

In [None]:
def plot_o2_density_histogram_by_transect(
    cube,
    region_range=(0, 30000),
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    dens_bins=np.linspace(23, 27, 500),
    o2_bins=np.linspace(0, 350, 500),
    ncols=3
):
    """
    Per-transect 2-D histograms of oxygen (x) against sigma-0 (y).

    The cube is loaded, cut to ``region_range`` along ``along``, and every
    transect dated within ``target_years`` / ``target_months`` that has at
    least one point where both oxygen and density are valid gets a panel.
    Counts are shown on a log colour scale fixed to 1–1000, with one shared
    colorbar. The number of valid points is printed for every selected
    transect, followed by the total.

    Parameters
    ----------
    cube : dataset exposing ``oxygen_concentration`` and
        ``potential_density`` with ``transect`` and ``along`` dimensions
        (presumably an xarray Dataset).
    region_range : (start, end) bounds for ``cube.sel(along=slice(...))``;
        also used (divided by 1000) in the figure title.
    target_years, target_months : sequences of int.
    dens_bins, o2_bins : bin specifications handed to ``np.histogram2d``.
        The y-axis runs from the last density edge down to the first.
    ncols : number of panel columns.

    Returns
    -------
    None. Prints a message and returns early when no transect has data.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    from matplotlib.colors import LogNorm
    from matplotlib.cm import ScalarMappable

    count_vmin, count_vmax = 1, 1000  # fixed log colour range shared by panels and colorbar

    cube = cube.load()
    subset = cube.sel(along=slice(*region_range))
    times = pd.to_datetime([str(t)[:8] for t in subset.transect.values])
    selected_idxs = [i for i, t in enumerate(times)
                     if t.year in target_years and t.month in target_months]

    # ─── Extract valid (O2, sigma) pairs once per selected transect ────────
    pairs = {}
    for idx in selected_idxs:
        ds_transect = subset.isel(transect=idx)
        o2 = ds_transect['oxygen_concentration'].values
        sigma = ds_transect['potential_density'].values - 1000
        mask = ~np.isnan(o2) & ~np.isnan(sigma)
        pairs[idx] = (o2[mask], sigma[mask])

    # A panel needs BOTH variables; filtering on oxygen alone used to leave
    # blank switched-off panels for transects lacking density.
    valid_idxs = [idx for idx in selected_idxs if pairs[idx][0].size > 0]

    if not valid_idxs:
        print("No valid transects with O₂ data.")
        return

    n_panels = len(valid_idxs)
    nrows = int(np.ceil(n_panels / ncols))
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                             figsize=(4.5 * ncols, 4 * nrows), squeeze=False)
    axes = axes.ravel()

    for i, idx in enumerate(valid_idxs):
        ax = axes[i]
        ax.set_facecolor('grey')

        o2_vals, sigma_vals = pairs[idx]
        hist, xedges, yedges = np.histogram2d(o2_vals, sigma_vals, bins=[o2_bins, dens_bins])

        ax.pcolormesh(xedges, yedges, hist.T, cmap='inferno',
                      norm=LogNorm(vmin=count_vmin, vmax=count_vmax))

        ax.set_title(times[idx].strftime('%d %b %Y'), fontsize=10)
        # Density from high to low, following the actual bin edges rather
        # than a hard-coded (27, 23).
        ax.set_ylim(yedges[-1], yedges[0])

        if i % ncols == 0:
            ax.set_ylabel("σ₀ (kg/m³ - 1000)")
        else:
            ax.set_yticklabels([])

        if i // ncols == nrows - 1:
            ax.set_xlabel("Oxygen (µmol/L)")
        else:
            ax.set_xticklabels([])

        ax.grid(True)

    for j in range(n_panels, len(axes)):
        fig.delaxes(axes[j])

    # ─── Report valid-point counts for every selected transect ─────────────
    counts = [pairs[idx][0].size for idx in selected_idxs]
    for idx, count in zip(selected_idxs, counts):
        print(f"Transect {idx} — {count} valid points")
    print(f"Total across all transects: {sum(counts)}")

    # Reserve the right margin for a shared colorbar.
    fig.subplots_adjust(right=0.88, top=0.92)
    cbar_ax = fig.add_axes([0.9, 0.1, 0.02, 0.8])
    sm = ScalarMappable(cmap='inferno', norm=LogNorm(vmin=count_vmin, vmax=count_vmax))
    sm.set_array([])
    cbar = fig.colorbar(sm, cax=cbar_ax)
    cbar.set_label('Log-scaled Count')

    # Title reflects the requested region; (0, 30000) yields "0 to 30km".
    start_km, end_km = region_range[0] / 1000, region_range[1] / 1000
    fig.suptitle(f"O₂–σ₀ Histograms {start_km:g} to {end_km:g}km", fontsize=20)
    plt.show()

# Per-transect O2–σ₀ histograms for 2024–2025, using 200 bin edges on each
# axis instead of the function's default 500.
plot_o2_density_histogram_by_transect(
    cube,
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    region_range=(0, 30000),
    o2_bins=np.linspace(0, 350, 200),
    dens_bins=np.linspace(23, 27, 200),
    ncols=3,
)

In [None]:
def plot_o2_density_histogram_all_transects(
    cube,
    region_range=(0, 30000),
    target_years=[2024, 2025],
    dens_bins=np.arange(23, 27.01, 0.01),
    o2_bins=np.arange(0, 350.01, 1),
):
    """
    Single 2-D histogram of oxygen (x) against sigma-0 (y) pooled over transects.

    The cube is loaded, cut to ``region_range`` along ``along``, and all
    points with valid oxygen and density from transects dated in
    ``target_years`` are pooled into one histogram drawn on a log colour
    scale fixed to 1–7500. The total number of valid points and the largest
    bin count are printed.

    Parameters
    ----------
    cube : dataset exposing ``oxygen_concentration`` and
        ``potential_density`` with ``transect`` and ``along`` dimensions
        (presumably an xarray Dataset).
    region_range : (start, end) bounds for ``cube.sel(along=slice(...))``;
        also used (divided by 1000) in the axes title.
    target_years : sequence of int; its min and max appear in the title.
    dens_bins, o2_bins : arrays of bin edges for ``np.histogram2d``. The
        y-axis runs from ``dens_bins[-1]`` down to ``dens_bins[0]``.

    Returns
    -------
    None. Prints a message and returns early when there is no valid data.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    from matplotlib.colors import LogNorm

    plt.rcParams.update({
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "axes.titlesize": 16,
        "axes.labelsize": 14,
        "legend.fontsize": 12,
        "figure.titlesize": 18})

    cube = cube.load()
    subset = cube.sel(along=slice(*region_range))
    times = pd.to_datetime([str(t)[:8] for t in subset.transect.values])
    selected_idxs = [i for i, t in enumerate(times)
                     if t.year in target_years]

    all_o2 = []
    all_sigma = []

    for idx in selected_idxs:
        ds = subset.isel(transect=idx)
        o2 = ds['oxygen_concentration'].values
        sigma = ds['potential_density'].values - 1000  # Convert to σ₀

        mask = ~np.isnan(o2) & ~np.isnan(sigma)
        all_o2.append(o2[mask])
        all_sigma.append(sigma[mask])

    # Test the number of points, not the number of transects: a non-empty
    # list of empty arrays would otherwise produce an all-zero histogram.
    if sum(chunk.size for chunk in all_o2) == 0:
        print("No valid O₂/σ₀ data.")
        return

    O2 = np.concatenate(all_o2)
    SIGMA = np.concatenate(all_sigma)
    print(f"Total valid O₂–σ₀ points: {len(O2)}")

    hist, xedges, yedges = np.histogram2d(O2, SIGMA, bins=[o2_bins, dens_bins])

    fig, ax = plt.subplots(figsize=(8, 6))
    pcm = ax.pcolormesh(xedges, yedges, hist.T, cmap='inferno',
                        norm=LogNorm(vmin=1, vmax=7500), rasterized=True)

    # Build the title from the arguments so it cannot disagree with what is
    # plotted; the defaults give "(2024–2025, 0–30 km)".
    first_year, last_year = min(target_years), max(target_years)
    year_text = f"{first_year}" if first_year == last_year else f"{first_year}–{last_year}"
    range_text = f"{region_range[0] / 1000:g}–{region_range[1] / 1000:g} km"

    ax.set_xlabel("Oxygen (µmol/L)")
    ax.set_ylabel("σ₀ (kg/m³ - 1000)")
    ax.set_title(f"O₂–σ₀ Histogram ({year_text}, {range_text})")
    ax.set_facecolor('grey')
    ax.set_ylim(dens_bins[-1], dens_bins[0])  # Flip y-axis (high to low density)

    cbar = fig.colorbar(pcm, ax=ax)
    cbar.set_label("Log-scaled Count")

    fig.tight_layout()
    # output_path = os.path.expanduser("~/Desktop/o2_density_hist_all.pdf")
    # fig.savefig(output_path, format='pdf', dpi=600, bbox_inches='tight')

    plt.show()
    print(f"Max bin count: {hist.max()}")

# Pooled O2–σ₀ histogram over 2020–2025, with the oxygen axis extended to 400.
plot_o2_density_histogram_all_transects(
    cube,
    target_years=list(range(2020, 2026)),
    region_range=(0, 30000),
    o2_bins=np.arange(0, 400.01, 1),
    dens_bins=np.arange(23, 27.01, 0.01),
)

In [None]:
def plot_o2_sigma_profile_by_transect(
    cube,
    basin_range=(5000, 15000),
    sill_range=(23500, 30000),
    shelf_range=(60000, 75000),
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    xlim=(0, 350),
    ylim=(27, 23),
    ncols=4
):
    """
    Grid of oxygen–sigma-0 diagrams, one panel per transect.

    For each transect dated within ``target_years`` / ``target_months``,
    oxygen and potential density (minus 1000) are averaged over the
    ``along`` dimension inside each of the three regions and scattered as
    oxygen (x) against sigma-0 (y). Transects with no valid pair in any
    region are skipped. One figure-level legend identifies the regions.

    Parameters
    ----------
    cube : dataset exposing ``oxygen_concentration`` and
        ``potential_density`` with ``transect`` and ``along`` dimensions
        (presumably an xarray Dataset).
    basin_range, sill_range, shelf_range : (start, end)
        Bounds passed to ``cube.sel(along=slice(start, end))``. Legend labels
        are generated from these values divided by 1000 and shown as "km",
        so ``along`` is assumed to be in metres — TODO confirm.
    target_years, target_months : sequences of int.
    xlim, ylim : tuple or None
        Axis limits for every panel; ``None`` keeps autoscaling.
    ncols : number of panel columns.

    Returns
    -------
    None. Prints a message and returns early when nothing can be plotted.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd

    plt.rcParams.update({
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "axes.titlesize": 14,
        "axes.labelsize": 14,
        "legend.fontsize": 10,
        "figure.titlesize": 16})

    def _range_label(name, bounds):
        """Format e.g. ('Sill', (20000, 30000)) as 'Sill (20–30 km)'."""
        start, end = bounds
        return f"{name} ({start / 1000:g}–{end / 1000:g} km)"

    # (label, bounds, colour). Labels come from the bounds so the legend
    # always matches the ranges that were actually averaged.
    regions = [
        (_range_label('Basin', basin_range), basin_range, 'cyan'),
        (_range_label('Sill', sill_range), sill_range, 'red'),
        (_range_label('Shelf', shelf_range), shelf_range, 'orange'),
    ]

    def _region_points(i, bounds):
        """Return the valid (O2, sigma-0) pairs of the along-averaged region profile."""
        region = cube.sel(along=slice(*bounds))
        o2 = region['oxygen_concentration'].isel(transect=i).mean(dim='along').values
        sigma = (region['potential_density'].isel(transect=i).mean(dim='along') - 1000).values
        mask = ~np.isnan(o2) & ~np.isnan(sigma)
        return o2[mask], sigma[mask]

    # ─── Filter Transects (data extracted once) ─────────────────────────────
    times_all = pd.to_datetime([str(t)[:8] for t in cube.transect.values])
    valid_data = []
    for i, t in enumerate(times_all):
        if t.year not in target_years or t.month not in target_months:
            continue
        series = []
        for label, bounds, color in regions:
            o2, sigma = _region_points(i, bounds)
            if o2.size:
                series.append((label, color, o2, sigma))
        if series:
            valid_data.append((t, series))

    if not valid_data:
        print("No transects with valid data.")
        return

    n_panels = len(valid_data)
    nrows = int(np.ceil(n_panels / ncols))
    # squeeze=False always returns a 2-D array, so a 1×1 grid does not hand
    # back a bare Axes object (which has no .flatten()).
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 5, nrows * 4),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.ravel()

    for j, (t, series) in enumerate(valid_data):
        ax = axes[j]
        ax.set_facecolor('grey')
        ax.grid()

        for label, color, o2, sigma in series:
            ax.scatter(o2, sigma, label=label, color=color, s=5)

        ax.set_title(t.strftime('%d %b %Y'))
        if j % ncols == 0:
            ax.set_ylabel("σ₀ (kg/m³ - 1000)")
        if j // ncols == nrows - 1:
            ax.set_xlabel("Oxygen (µmol/L)")
        # Tolerate None so callers can request autoscaling.
        if xlim is not None:
            ax.set_xlim(*xlim)
        if ylim is not None:
            ax.set_ylim(*ylim)

    # ─── Remove unused axes ─────────────────────────────
    for j in range(n_panels, len(axes)):
        fig.delaxes(axes[j])

    legend_handles = [plt.Line2D([0], [0], marker='o', color='w', label=label,
                                 markerfacecolor=color, markersize=10)
                      for label, _, color in regions]

    fig.legend(handles=legend_handles, loc='upper right', fontsize=14)
    # NOTE(review): the mean above is taken over `along`, not over depth, so
    # "Depth-Averaged" in this title looks inaccurate — confirm before renaming.
    fig.suptitle("Depth-Averaged O₂–σ₀ Profiles by Transect", fontsize=22)
    plt.tight_layout(rect=[0, 0, 1, 0.95])

# Region-averaged O2–σ₀ panels for 2024–2025; shelf_range is left at the
# function's default.
plot_o2_sigma_profile_by_transect(
    cube,
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    basin_range=(0, 15000),
    sill_range=(20000, 30000),
    xlim=(0, 350),
    ylim=(27, 23),
    ncols=3,
)

In [None]:
def compare_ts_by_month_markers(cube, basin_range=(5000, 15000), sill_range=(23500, 30000),
                                target_year=2024, target_months=(4, 5, 6, 7)):
    """
    Scatter along-track-mean T–S values for the basin and sill regions on a
    single axis, coloured by region and marked by month, over σ₀ isopycnals.

    Parameters:
    - cube: dataset with a `transect` coordinate whose labels start with
      YYYYMMDD (the first 8 characters are parsed as the date), an `along`
      coordinate, and `temperature` / `salinity` variables
    - basin_range: (start, end) slice of `along` treated as the basin
      (presumably metres along the transect — TODO confirm)
    - sill_range: (start, end) slice of `along` treated as the sill
    - target_year: year of the transects to include (default 2024)
    - target_months: month numbers (1–12) to include; each month gets its own
      marker (default April–July)

    Returns:
    - None; the figure is created as a side effect.
    """
    import matplotlib.pyplot as plt
    import gsw
    import pandas as pd
    import numpy as np
    from scipy.optimize import root_scalar
    from matplotlib.lines import Line2D

    # Fixed English abbreviations: strftime('%b') follows the process locale,
    # which would break the marker lookup and legend under a non-English locale.
    month_abbr = ('Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                  'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec')
    marker_cycle = ('o', 's', '^', 'D', 'v', 'P', 'X', '*', '<', '>', 'h', 'p')

    # ─── Setup ─────────────────────────────
    # Markers are keyed by month number and assigned in the order requested;
    # the default months therefore keep the markers o, s, ^, D.
    month_markers = {m: marker_cycle[k % len(marker_cycle)]
                     for k, m in enumerate(list(target_months))}
    regions = [('Basin', basin_range, 'cyan'), ('Sill', sill_range, 'red')]

    transects = cube.transect.values
    times = pd.to_datetime([str(t)[:8] for t in transects])
    valid = [(i, t) for i, t in enumerate(times)
             if t.year == target_year and t.month in month_markers]
    if not valid:
        print(f"No transects found for {target_year} in months {list(month_markers)}; "
              "only the isopycnals will be drawn.")

    # ─── Isopycnals ─────────────────────────────
    # NOTE(review): gsw.sigma0 is defined for Absolute Salinity / Conservative
    # Temperature; the grid values are passed as-is, as in the original code.
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 300), np.linspace(5.3, 20, 300))
    sigma = gsw.sigma0(S_grid, T_grid)
    isopycnal_levels = [25.6, 25.7, 25.8, 25.9, 26.0, 26.1, 26.2]

    # Place each contour label where the isopycnal crosses T = T_label.
    T_label = 6.0
    label_positions = {}
    for iso in isopycnal_levels:
        sol = root_scalar(lambda S, iso=iso: gsw.sigma0(S, T_label) - iso,
                          bracket=[28, 35], method='brentq')
        label_positions[iso] = (sol.root if sol.converged else np.nan, T_label)

    # ─── Plot ─────────────────────────────
    fig, ax = plt.subplots(figsize=(18, 12))
    ax.set_facecolor('grey')

    # Contours are drawn one level at a time so each can take its own manual
    # label position.
    for level in isopycnal_levels:
        cs = ax.contour(S_grid, T_grid, sigma, levels=[level], colors='black',
                        linewidths=0.5, linestyles='--')
        if not np.isnan(label_positions[level][0]):
            ax.clabel(cs, fmt='%1.2f', fontsize=8, inline=False, inline_spacing=0.1,
                      manual=[label_positions[level]])

    # Plot T–S points: one scatter per (transect, region), averaged over `along`.
    for i, t in valid:
        month_name = month_abbr[t.month - 1]
        section = cube.isel(transect=i)

        for region_name, (start, end), color in regions:
            region = section.sel(along=slice(start, end))
            T = region['temperature'].mean(dim='along').values
            S = region['salinity'].mean(dim='along').values
            mask = ~np.isnan(T) & ~np.isnan(S)
            if np.any(mask):
                ax.scatter(S[mask], T[mask],
                           color=color,
                           marker=month_markers[t.month],
                           s=50,
                           label=f"{region_name} – {month_name}")

    # Format
    ax.set_xlim(30.5, 34)
    ax.set_ylim(5, 16)
    ax.set_xlabel("Salinity (psu)", fontsize=20)
    ax.set_ylabel("Potential Temperature (°C)", fontsize=20)

    # Legend: proxy handles give one entry per (month, region) pair regardless
    # of how many transects were actually plotted.
    handles = [
        Line2D([0], [0], marker=month_markers[m], color='w',
               markerfacecolor=color, markersize=10,
               label=f"{region_name} – {month_abbr[m - 1]}")
        for m in month_markers
        for region_name, _, color in regions
    ]
    ax.legend(handles=handles, fontsize=14, loc='lower right', ncol=2)

    # Final layout
    plt.tight_layout(rect=[0.05, 0.05, 0.95, 0.95])
    fig.suptitle(f"TS Profiles by Region and Month ({target_year})", fontsize=26)

# Draw the basin/sill T–S comparison using the function's default along-track ranges.
compare_ts_by_month_markers(cube)

In [None]:
# Display the dataset repr to inspect its dimensions, coordinates and variables.
cube

In [None]:
def plot_chlorophyll_section_grid(cube, topo, target_years, ncols=3, xlim=(77, 0), vmin=0, vmax=5):
    """
    Plot a grid of chlorophyll sections from a transect cube, filtered by year.
    Includes isopycnals and bathymetry shading.

    Parameters:
    - cube: xarray.Dataset with dimensions (transect, depth, along); reads the
      variables chlorophyll, potential_density, longitude, latitude and, when
      present, time_top. Transect labels must start with YYYYMMDD.
    - topo: bathymetry dataset; its 'Band1' variable is sampled at each
      section's lon/lat and negated to give the sea-floor depth
    - target_years: list of years to include
    - ncols: number of columns in the grid
    - xlim: x-axis limits (km) for each subplot
    - vmin, vmax: colour-scale limits for chlorophyll

    Returns:
    - None; the figure is created as a side effect. If no transect in
      target_years has any non-NaN chlorophyll, a message is printed and no
      figure is created.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import xarray as xr

    # NOTE: this changes the global matplotlib defaults, so later figures in
    # the same session inherit these font sizes.
    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 20,
        'figure.titlesize': 20})

    # ─── Filter transects with valid chlorophyll data ─────────────────────────────
    valid_transects = []
    for t in cube.transect.values:
        if int(str(t)[:4]) in target_years:
            chl = cube.sel(transect=t)['chlorophyll']
            if np.any(~np.isnan(chl.values)):
                valid_transects.append(str(t))

    n = len(valid_transects)
    if n == 0:
        # plt.subplots cannot build a grid with zero rows; bail out early.
        print(f"No transects with valid chlorophyll data in years {target_years}.")
        return

    cube_sel = cube.sel(transect=valid_transects)

    nrows = int(np.ceil(n / ncols))
    # squeeze=False always returns a 2-D array of axes, so flatten() also works
    # for a 1x1 grid.
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 1.5 * 6.4, nrows * 1 * 4.8),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.flatten()

    for ax, tran_name in zip(axes, valid_transects):
        section = cube_sel.sel(transect=tran_name)

        chl = section['chlorophyll'].values
        pdens = section['potential_density'].values - 1000
        depth = section['depth'].values
        along = section['along'].values
        along_km = along / 1000
        lon = section['longitude'].values
        lat = section['latitude'].values

        # Sample the bathymetry at each along-track position (nearest grid cell).
        interp_bathy = topo['Band1'].interp(
            lon=xr.DataArray(lon, dims='along'),
            lat=xr.DataArray(lat, dims='along'),
            method='nearest')
        ocean_floor = -interp_bathy.values

        ax.fill_between(along_km, ocean_floor, 500,
                        where=~np.isnan(ocean_floor), facecolor='grey', zorder=1)

        pc = ax.pcolormesh(along_km, depth, chl,
                           shading='auto', cmap='viridis',
                           vmin=vmin, vmax=vmax, zorder=2, rasterized=True)

        ax.plot(along_km, ocean_floor, color='black', linewidth=1.5)

        iso = ax.contour(along_km, depth, pdens, colors='black',
                         levels=np.arange(24, 27.5, 0.5), linewidths=0.5)
        ax.clabel(iso, fmt='%1.2f', fontsize=6)

        date_str = pd.to_datetime(tran_name[:8], format='%Y%m%d').strftime('%B %d')
        leg = 'Out' if 'out' in tran_name else 'Return'
        ax.set_title(f"{date_str} ({leg})", pad=2)

        # Top axis time labels
        if 'time_top' in section:
            time_top = section['time_top'].values
            nticks = 8
            idx_ticks = np.linspace(0, len(along_km) - 1, nticks, dtype=int)
            tick_locs = along_km[idx_ticks]
            tick_times = time_top[idx_ticks]
            valid_mask = ~pd.isnull(tick_times)
            tick_locs = tick_locs[valid_mask]
            tick_labels = [pd.to_datetime(t).strftime('%b %d %H:%M') for t in tick_times[valid_mask]]

            ax_top = ax.secondary_xaxis('top')
            ax_top.set_xticks(tick_locs)
            ax_top.set_xticklabels(tick_labels, rotation=30, ha='center', fontsize=8)

    # Explicit limits on every used panel. The y-limits are given deep-to-shallow
    # so depth increases downwards; the x orientation follows the order of xlim.
    for ax in axes[:n]:
        ax.set_xlim(xlim)
        ax.set_ylim(440, 0)

    # Remove the unused panels of the last row.
    for unused_ax in axes[n:]:
        fig.delaxes(unused_ax)

    # Leave room on the right for the colour bar and at the edges for the
    # figure-level axis labels.
    plt.tight_layout(rect=[0.03, 0.04, 0.85, 0.94])
    cbar_ax = fig.add_axes([0.87, 0.25, 0.015, 0.5])
    # Every panel shares vmin/vmax, so the last mesh represents all of them.
    cbar = fig.colorbar(pc, cax=cbar_ax)
    cbar.ax.tick_params(labelsize=12)
    cbar.set_label("Chlorophyll (mg/m³)", fontsize=30)

    fig.suptitle(f"Chlorophyll Sections ({', '.join(map(str, target_years))})", fontsize=40)
    x_center = (0.03 + 0.85) / 2
    fig.text(x_center, 0.01, 'Distance along transect (km)', ha='center', va='center', fontsize=30)
    fig.text(0.012, 0.5, 'Depth (m)', ha='center', va='center', rotation='vertical', fontsize=30)

# Chlorophyll sections for 2024–2025 in a 4-column grid, colour scale 0–2.
plot_chlorophyll_section_grid(
    cube,
    topo,
    target_years=[2024, 2025],
    ncols=4,
    xlim=(77, 0),
    vmin=0,
    vmax=2
)

In [None]:
# Chlorophyll sections for 2024–2025 in a 4-column grid, with a tighter colour scale (0–1).
plot_chlorophyll_section_grid(
    cube,
    topo,
    target_years=[2024, 2025],
    ncols=4,
    xlim=(77, 0),
    vmin=0,
    vmax=1
)

In [None]:
def plot_backscatter_section_grid(
    cube, topo, target_years, ncols=3, xlim=(77, 0), vmin=0, vmax=2
):
    """
    Plot a grid of backscatter_700 sections from a transect cube, filtered by year.
    Includes isopycnals and bathymetry shading.

    Parameters:
    - cube: xarray.Dataset indexed by transect, depth and along; reads the
      variables backscatter_700, potential_density, longitude and latitude.
      Transect labels must start with YYYYMMDD.
    - topo: bathymetry dataset; its 'Band1' variable is sampled at each
      section's lon/lat and negated to give the sea-floor depth
    - target_years: list of years to include
    - ncols: number of columns in the grid
    - xlim: x-axis limits (km) for each subplot
    - vmin, vmax: colour-scale limits for backscatter

    Returns:
    - None; the figure is created as a side effect. If no transect in
      target_years has any non-NaN backscatter, a message is printed and no
      figure is created.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import xarray as xr

    # NOTE: this changes the global matplotlib defaults, so later figures in
    # the same session inherit these font sizes.
    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 20,
        'figure.titlesize': 20})

    # ─── Filter transects with valid backscatter data ───────────────
    valid_transects = []
    for t in cube.transect.values:
        if int(str(t)[:4]) in target_years:
            bs = cube.sel(transect=t)['backscatter_700']
            if np.any(~np.isnan(bs.values)):
                valid_transects.append(str(t))

    n = len(valid_transects)
    if n == 0:
        # plt.subplots cannot build a grid with zero rows; bail out early.
        print(f"No transects with valid backscatter_700 data in years {target_years}.")
        return

    cube_sel = cube.sel(transect=valid_transects)

    nrows = int(np.ceil(n / ncols))
    # squeeze=False always returns a 2-D array of axes, so flatten() also works
    # for a 1x1 grid.
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 1.5 * 6.4, nrows * 1 * 4.8),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.flatten()

    for ax, tran_name in zip(axes, valid_transects):
        section = cube_sel.sel(transect=tran_name)

        bs = section['backscatter_700'].values
        pdens = section['potential_density'].values - 1000
        depth = section['depth'].values
        along = section['along'].values
        along_km = along / 1000
        lon = section['longitude'].values
        lat = section['latitude'].values

        # Sample the bathymetry at each along-track position (nearest grid cell).
        interp_bathy = topo['Band1'].interp(
            lon=xr.DataArray(lon, dims='along'),
            lat=xr.DataArray(lat, dims='along'),
            method='nearest')
        ocean_floor = -interp_bathy.values

        ax.fill_between(along_km, ocean_floor, 500,
                        where=~np.isnan(ocean_floor), facecolor='grey', zorder=1)

        pc = ax.pcolormesh(along_km, depth, bs,
                           shading='auto', cmap='magma',
                           vmin=vmin, vmax=vmax, zorder=2, rasterized=True)

        ax.plot(along_km, ocean_floor, color='black', linewidth=1.5)

        iso = ax.contour(along_km, depth, pdens, colors='white',
                         levels=np.arange(24, 27.5, 0.5), linewidths=0.5)
        ax.clabel(iso, fmt='%1.2f', fontsize=6)

        date_str = pd.to_datetime(tran_name[:8], format='%Y%m%d').strftime('%B %d')
        leg = 'Out' if 'out' in tran_name else 'Return'
        ax.set_title(f"{date_str} ({leg})", pad=2)

    # Explicit limits on every used panel. The y-limits are given deep-to-shallow
    # so depth increases downwards; the x orientation follows the order of xlim.
    for ax in axes[:n]:
        ax.set_xlim(xlim)
        ax.set_ylim(440, 0)

    # Remove the unused panels of the last row.
    for unused_ax in axes[n:]:
        fig.delaxes(unused_ax)

    # Leave room on the right for the colour bar and at the edges for the
    # figure-level axis labels.
    plt.tight_layout(rect=[0.03, 0.04, 0.85, 0.94])
    cbar_ax = fig.add_axes([0.87, 0.25, 0.015, 0.5])
    # Every panel shares vmin/vmax, so the last mesh represents all of them.
    cbar = fig.colorbar(pc, cax=cbar_ax)
    cbar.ax.tick_params(labelsize=12)
    cbar.set_label("Backscatter (700 nm)", fontsize=30)

    fig.suptitle(f"Backscatter_700 Sections ({', '.join(map(str, target_years))})", fontsize=40)
    x_center = (0.03 + 0.85) / 2
    fig.text(x_center, 0.01, 'Distance along transect (km)', ha='center', va='center', fontsize=30)
    fig.text(0.012, 0.5, 'Depth (m)', ha='center', va='center', rotation='vertical', fontsize=30)

# Backscatter sections for 2024–2025 in a 4-column grid, colour scale 0–0.0025.
plot_backscatter_section_grid(
    cube,
    topo,
    target_years=[2024, 2025],
    ncols=4,
    xlim=(77, 0),
    vmin=0,
    vmax=0.0025
)

In [None]:
# Backscatter sections for 2024–2025 in a 4-column grid, with a tighter colour scale (0–0.001).
plot_backscatter_section_grid(
    cube,
    topo,
    target_years=[2024, 2025],
    ncols=4,
    xlim=(77, 0),
    vmin=0,
    vmax=0.001
)

In [None]:
# Quick look at the whole backscatter_700 variable with xarray's default plot.
# NOTE(review): the variable presumably spans (transect, depth, along); for
# arrays with more than two dimensions xarray's .plot() draws a histogram of
# the values — confirm this is the intended view.
cube['backscatter_700'].plot()

In [None]:
import numpy as np

# Summary statistics (min, max, mean) of all non-NaN backscatter_700 samples in `cube`.
backscatter_data = cube['backscatter_700'].values

# Boolean indexing drops the NaNs and flattens the result to 1-D.
valid_bs = backscatter_data[~np.isnan(backscatter_data)]

# valid_bs is already NaN-free, so the nan* reductions act like plain
# min/max/mean here.
bs_min = np.nanmin(valid_bs)
bs_max = np.nanmax(valid_bs)
bs_mean = np.nanmean(valid_bs)

# NOTE(review): min and max are printed with 3 decimals, while the section
# plots in this notebook use vmax values of 0.001–0.0025; numbers of that size
# would print as 0.000/0.001 — consider more digits or scientific notation.
print(f"Backscatter_700 range: min = {bs_min:.3f}, max = {bs_max:.3f}, mean = {bs_mean:.7f}")

In [None]:
# Display the dataset repr to inspect its dimensions, coordinates and variables.
cube