In [None]:
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import numpy as np
import os
import cmocean as cm
import waypoint_distance as wd
import pandas as pd
from pathlib import Path
from datetime import datetime
from matplotlib.dates import DateFormatter
import gsw
import matplotlib.dates as mdates
%matplotlib widget

# Load the input datasets. Every path goes through os.path.expanduser so '~' is
# resolved the same way for all three files, independent of how the I/O backend
# treats unexpanded home-directory paths.
# cube: used below with 'along'/'transect' coordinates and temperature/salinity.
cube = xr.open_dataset(os.path.expanduser('~/Desktop/Summer 2025 Python/calvert_cube.nc'))
# topo: presumably bathymetry/topography (going by the file name) — TODO confirm.
topo = xr.open_dataset(os.path.expanduser('~/Desktop/Summer 2025 Python/british_columbia_3_msl_2013.nc'))
# ds: station data indexed by 'row' (station, time, depth, temperature, salinity).
ds = xr.open_dataset(os.path.expanduser('~/Desktop/Summer 2025 Python/Hakai_calvert.nc'))

def plot_ts_from_station(ds, station_id, years=[2023, 2024, 2025],
                         depth_range=(100, 450), xlim=None, ylim=None, target_months=None):
    """
    Draw a T–S diagram for one station, one scatter series per sampling time.

    Parameters
    ----------
    ds : xarray.Dataset
        Must provide 'station', 'time', 'depth', 'temperature' and 'salinity'
        along a 'row' dimension.
    station_id :
        Value compared against ``ds['station']`` to pick the station.
    years : sequence of int
        Calendar years to keep (the sequence is never modified).
    depth_range : (min, max)
        Inclusive depth window; samples outside it are not plotted.
    xlim, ylim : tuple or None
        Axis limits. When None they are derived from the plotted samples
        with a 0.1 margin on each side.
    target_months : sequence of int or None
        Optional month filter applied after the year filter.

    Returns
    -------
    None. A new figure is created and ``plt.rcParams`` is updated globally.
    Prints a message and returns early when fewer than two rows survive
    the filters.
    """
    import itertools
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import gsw

    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 20,
        'figure.titlesize': 20})

    # Filter and clean
    ds_station = ds.where(ds['station'] == station_id, drop=True)
    ds_station = ds_station.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time'])
    ds_station = ds_station.where(ds_station['time'].dt.year.isin(years), drop=True)
    if target_months is not None:
        ds_station = ds_station.where(ds_station['time'].dt.month.isin(target_months), drop=True)

    if ds_station.sizes['row'] < 2:
        print(f"No valid data for station {station_id} in years {years}.")
        return

    # One colour per distinct timestamp; markers cycle so neighbouring colours stay distinguishable.
    times = np.unique(ds_station['time'].values)
    colors = plt.colormaps.get_cmap('jet')(np.linspace(0, 1, len(times)))
    markers = itertools.cycle(['o', 'D', 's', 'X', 'P', '^', 'v', '<', '>', '*', 'h'])

    fig, ax = plt.subplots(figsize=(1.5 * 1.5 * 6.4, 1.5 * 4.8))
    ax.set_facecolor('lightgrey')

    # Sigma-theta background grid
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 200),
                                 np.linspace(5.3, 20, 200))
    sigma = gsw.sigma0(S_grid, T_grid)

    # (levels, colour, line width, label the contour?). Only the black sets carry
    # inline labels; the coloured single-level contours are left unlabelled.
    isopycnal_styles = [
        (np.linspace(24, 27, 7), 'black', 0.51, True),
        ([25.6], 'white', 0.5, False),
        ([25.7], 'lime', 0.5, False),
        ([25.8], 'red', 0.5, False),
        ([25.9], 'blue', 0.5, False),
        ([26.0], 'black', 0.51, True),
        ([26.1], 'purple', 0.5, False),
        ([26.2], 'salmon', 0.5, False),
        ([26.3], 'yellow', 0.5, False),
        ([26.4], 'cyan', 0.5, False),
    ]
    for levels, color, lw, labelled in isopycnal_styles:
        cs = ax.contour(S_grid, T_grid, sigma, levels=levels,
                        colors=color, linewidths=lw, linestyles='--')
        if labelled:
            ax.clabel(cs, fmt='%1.2f', fontsize=9, inline=True)

    # Plot each profile (all rows sharing one timestamp)
    plotted_S, plotted_T = [], []
    for color, t in zip(colors, times):
        profile = ds_station.where(ds_station['time'] == t, drop=True)
        if profile.sizes['row'] < 2:
            continue

        depth = profile['depth'].values
        in_range = (depth >= depth_range[0]) & (depth <= depth_range[1])
        temp = profile['temperature'].values[in_range]
        sal = profile['salinity'].values[in_range]

        valid = ~np.isnan(temp) & ~np.isnan(sal)
        if not np.any(valid):
            continue

        ax.scatter(sal[valid], temp[valid],
                   color=color,
                   s=90,
                   marker=next(markers),
                   linewidth=0.2,
                   label=pd.Timestamp(t).strftime('%d %b %Y'))
        plotted_S.append(sal[valid])
        plotted_T.append(temp[valid])

    # Axis limits default to the extent of what was actually drawn
    if plotted_S:
        all_S = np.concatenate(plotted_S)
        all_T = np.concatenate(plotted_T)
        if xlim is None:
            xlim = (np.nanmin(all_S) - 0.1, np.nanmax(all_S) + 0.1)
        if ylim is None:
            ylim = (np.nanmin(all_T) - 0.1, np.nanmax(all_T) + 0.1)

    ax.set_xlabel("Salinity (psu)")
    ax.set_ylabel(r"$\theta$ (°C)")
    ax.set_title(f"T–S Diagram at {station_id}")
    if plotted_S:
        # Skip the legend when nothing was plotted; it would only raise a warning.
        ax.legend(fontsize=9, loc='upper right')
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    plt.tight_layout()

In [None]:
def plot_ts_shelf_basin(cube, basin_range=(5000, 15000), sill_range=(23500, 30000), shelf_range=(65000, 70000),
                        target_years=[2024], target_months=[5], which_date='first', xlim=None, ylim=None):
    """
    T–S diagram contrasting basin vs sill vs shelf water.

    For each along-track range, the first (or second) transect that falls in
    ``target_years``/``target_months`` is scattered in T–S space on top of
    dashed sigma-0 contours.

    NOTE: a later cell defines another function with this same name; once that
    cell has run, this definition is no longer reachable by name.

    Parameters
    ----------
    cube : xarray.Dataset
        Needs 'temperature' and 'salinity' with 'along' and 'transect'
        coordinates; the first 8 characters of each transect value are parsed
        as a date.
    basin_range, sill_range, shelf_range : (start, end)
        Along-track limits passed to ``cube.sel(along=slice(start, end))``.
        Legend entries report them in km (value / 1000), so the legend always
        matches the ranges actually used.
    target_years, target_months : sequence of int
        Transects outside these years/months are ignored (never modified).
    which_date : {'first', 'second'}
        Which matching transect to plot. 'second' silently falls back to the
        first when only one transect matches.
    xlim, ylim : tuple or None
        Optional axis limits.

    Returns
    -------
    None. Creates a new figure and updates ``plt.rcParams`` globally.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import gsw
    from scipy.optimize import root_scalar

    assert which_date in ['first', 'second'], "which_date must be 'first' or 'second'"

    plt.rcParams.update({
        "xtick.labelsize": 30,
        "ytick.labelsize": 30,
        "axes.titlesize": 30,
        "axes.labelsize": 30,
        "legend.fontsize": 20,
        "figure.titlesize": 40})

    cube = cube.load()
    fig, ax = plt.subplots(figsize=(1.5 * 1.5 * 6.4, 1.5 * 4.8))
    ax.set_facecolor('grey')

    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 300), np.linspace(5.3, 20, 300))
    sigma = gsw.sigma0(S_grid, T_grid)

    # Pin every isopycnal label at theta = T_label: solve sigma0(S, T_label) = level
    # for S so the labels line up in a row near the bottom of the diagram.
    T_label = 6.0
    isopycnal_levels = [23.0, 23.5, 24.0, 24.5, 25.0, 25.5, 25.6, 25.7, 25.8, 25.9, 26.0, 26.1, 26.2, 26.3, 26.4, 26.5]
    manual_labels = {}
    for iso in isopycnal_levels:
        sol = root_scalar(lambda S, iso=iso: gsw.sigma0(S, T_label) - iso,
                          bracket=[28, 35], method='brentq')
        manual_labels[iso] = (sol.root if sol.converged else np.nan, T_label)

    for levels, color, lw in [
        (np.linspace(23, 27, 9), 'black', 0.5),
        ([25.6], 'white', 0.5),
        ([25.7], 'lime', 0.5),
        ([25.8], 'red', 0.5),
        ([25.9], 'blue', 0.5),
        ([26.0], 'black', 0.5),
        ([26.1], 'purple', 0.5),
        ([26.2], 'salmon', 0.5),
        ([26.3], 'yellow', 0.5),
        ([26.4], 'cyan', 0.5)
    ]:
        cs = ax.contour(S_grid, T_grid, sigma, levels=levels,
                        colors=color, linewidths=lw, linestyles='--')
        for level in levels:
            if level in manual_labels:
                ax.clabel(cs, fmt='%1.2f', fontsize=10, inline=True,
                          manual=[manual_labels[level]])

    # The transect coordinate is unaffected by slicing along 'along', so the
    # transect to plot is chosen once for all regions.
    times = pd.to_datetime([str(t)[:8] for t in cube.transect.values])
    selected_idxs = sorted(i for i, t in enumerate(times)
                           if t.year in target_years and t.month in target_months)

    if selected_idxs:
        use_second = which_date == 'second' and len(selected_idxs) > 1
        idx = selected_idxs[1] if use_second else selected_idxs[0]

        for (start, end), name, color, marker in [
            (basin_range, 'Basin water', 'cyan', 'o'),
            (sill_range, 'Sill water', 'red', 'o'),
            (shelf_range, 'QCS water', 'salmon', 'o'),
        ]:
            # Build the legend text from the real limits instead of fixed strings.
            label = f"{name} ({start / 1000:g}–{end / 1000:g} km)"
            small = cube.sel(along=slice(start, end)).isel(transect=idx)[['temperature', 'salinity']]

            T = small['temperature'].values.ravel()
            S = small['salinity'].values.ravel()
            mask = ~np.isnan(T) & ~np.isnan(S)
            ax.scatter(S[mask], T[mask], s=20, color=color, label=label, marker=marker)

    ax.set_xlabel("Salinity (psu)")
    ax.set_ylabel(r"$\theta$ (°C)")
    ax.set_title(f"T–S Diagram: Basin vs Sill vs Shelf Water ({which_date.capitalize()} transect)")
    ax.legend()
    if xlim: ax.set_xlim(xlim)
    if ylim: ax.set_ylim(ylim)
    plt.tight_layout()

# May 2024 basin / sill / shelf comparison: one figure for the first matching
# transect of the month, then one for the second.
for transect_choice in ('first', 'second'):
    plot_ts_shelf_basin(
        cube,
        basin_range=(5000, 15000),
        sill_range=(23500, 30000),
        shelf_range=(65000, 75000),
        target_years=[2024],
        target_months=[5],
        which_date=transect_choice,
        xlim=(30, 34),
        ylim=(5, 12),
    )

# Some crazy thinking that may lead to cool results indicating mixing? #

In [None]:
def plot_ts_shelf_basin(cube, region_range=(0, 30000),
                        target_years=[2024], target_months=[5], which_date='first',
                        xlim=None, ylim=None):
    """
    T–S diagram with individual profiles colored by along-track distance using a colormap.

    NOTE: this reuses the name of a function defined in an earlier cell and
    replaces it once this cell runs.

    Parameters
    ----------
    cube : xarray.Dataset
        Needs 'temperature' and 'salinity' with 'along' and 'transect'
        coordinates; the first 8 characters of each transect value are parsed
        as a date.
    region_range : (start, end)
        Along-track limits used both for ``cube.sel(along=slice(start, end))``
        and as the colour-scale limits.
    target_years, target_months : sequence of int
        Transects outside these years/months are ignored (never modified).
    which_date : {'first', 'second'}
        Which matching transect to plot. 'second' silently falls back to the
        first when only one transect matches.
    xlim, ylim : tuple or None
        Optional axis limits.

    Returns
    -------
    None. Creates a new figure and updates ``plt.rcParams`` globally.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import gsw
    from scipy.optimize import root_scalar
    from matplotlib.colors import Normalize

    assert which_date in ['first', 'second'], "which_date must be 'first' or 'second'"

    plt.rcParams.update({
        "xtick.labelsize": 30,
        "ytick.labelsize": 30,
        "axes.titlesize": 30,
        "axes.labelsize": 30,
        "legend.fontsize": 20,
        "figure.titlesize": 40})

    cube = cube.load()
    fig, ax = plt.subplots(figsize=(1.5 * 1.5 * 6.4, 1.5 * 4.8))
    ax.set_facecolor('grey')

    # ─── Isopycnals ─────────────────────────────
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 300), np.linspace(5.3, 20, 300))
    sigma = gsw.sigma0(S_grid, T_grid)

    # Pin every isopycnal label at theta = T_label: solve sigma0(S, T_label) = level for S.
    T_label = 6.0
    isopycnal_levels = [23.0, 23.5, 24.0, 24.5, 25.0, 25.5, 25.6, 25.7, 25.8, 25.9, 26.0, 26.1, 26.2, 26.3, 26.4, 26.5]
    manual_labels = {}
    for iso in isopycnal_levels:
        sol = root_scalar(lambda S, iso=iso: gsw.sigma0(S, T_label) - iso,
                          bracket=[28, 35], method='brentq')
        manual_labels[iso] = (sol.root if sol.converged else np.nan, T_label)

    for levels, color, lw in [
        (np.linspace(23, 27, 9), 'black', 0.5),
        ([25.6], 'white', 0.5),
        ([25.7], 'lime', 0.5),
        ([25.8], 'red', 0.5),
        ([25.9], 'blue', 0.5),
        ([26.0], 'black', 0.5),
        ([26.1], 'purple', 0.5),
        ([26.2], 'salmon', 0.5),
        ([26.3], 'yellow', 0.5),
        ([26.4], 'cyan', 0.5)
    ]:
        cs = ax.contour(S_grid, T_grid, sigma, levels=levels,
                        colors=color, linewidths=lw, linestyles='--')
        for level in levels:
            if level in manual_labels:
                ax.clabel(cs, fmt='%1.2f', fontsize=10, inline=True,
                          manual=[manual_labels[level]])

    # ─── Color by Along-Track Distance ─────────────────────────────
    subset = cube.sel(along=slice(*region_range))
    times = pd.to_datetime([str(t)[:8] for t in subset.transect.values])
    selected_idxs = sorted(i for i, t in enumerate(times)
                           if t.year in target_years and t.month in target_months)

    if selected_idxs:
        use_second = which_date == 'second' and len(selected_idxs) > 1
        idx = selected_idxs[1] if use_second else selected_idxs[0]
        small = subset.isel(transect=idx)[['temperature', 'salinity']]

        # Colormap registry lookup; matplotlib.cm.get_cmap is gone in matplotlib >= 3.9.
        cmap = plt.colormaps['jet']
        norm = Normalize(vmin=region_range[0], vmax=region_range[1])

        # One scatter per along-track column, selected by position so the lookup
        # does not depend on matching float coordinate labels.
        for a_idx, along in enumerate(small['along'].values):
            T = small['temperature'].isel(along=a_idx).values
            S = small['salinity'].isel(along=a_idx).values
            mask = ~np.isnan(T) & ~np.isnan(S)
            if np.any(mask):
                ax.scatter(S[mask], T[mask], s=20, color=cmap(norm(along)), marker='o')

        # The points were coloured manually, so the colorbar needs its own mappable.
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        plt.colorbar(sm, ax=ax, orientation='vertical', label='Along-transect distance (m)')

    # ─── Axis and Layout ─────────────────────────────
    ax.set_xlabel("Salinity (psu)")
    ax.set_ylabel(r"$\theta$ (°C)")
    ax.set_title(f"T–S Diagram ({which_date.capitalize()} transect)")
    if xlim: ax.set_xlim(xlim)
    if ylim: ax.set_ylim(ylim)
    plt.tight_layout()


# One along-track-coloured T–S figure per month, March–November 2024, each using
# the first transect of the month. The June call used to be pasted three times
# and drew three identical figures; each month is now drawn exactly once.
for month in range(3, 12):
    plot_ts_shelf_basin(
        cube,
        region_range=(0, 70000),
        target_years=[2024],
        target_months=[month],
        which_date='first',
        xlim=(30, 34),
        ylim=(5, 12),
    )


In [None]:
def plot_ts_shelf_basin_by_transect(
    cube,
    region_range=(0, 30000),
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    xlim=(30.5, 34),
    ylim=(5, 16),
    ncols=4,
    output_path="~/Desktop/ts_grid_raster600dpi.pdf",
):
    """
    Grid of T–S diagrams, one panel per transect, with points coloured by
    along-track distance. The figure is also written to a PDF.

    Parameters
    ----------
    cube : xarray.Dataset
        Needs 'temperature' and 'salinity' with 'along' and 'transect'
        coordinates; the first 8 characters of each transect value are parsed
        as a date.
    region_range : (start, end)
        Along-track limits used for ``cube.sel(along=slice(start, end))`` and
        as the colour-scale limits.
    target_years, target_months : sequence of int
        Transects outside these years/months are skipped (never modified).
    xlim, ylim : tuple
        Axis limits applied to every panel.
    ncols : int
        Number of panel columns.
    output_path : str or None
        Where the PDF is saved ('~' is expanded). None disables saving.

    Returns
    -------
    None. Prints a message and returns without creating a figure when no
    transect matches. Updates ``plt.rcParams`` globally.
    """
    import os
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import gsw
    from scipy.optimize import root_scalar
    from matplotlib.colors import Normalize

    plt.rcParams.update({
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "axes.titlesize": 14,
        "axes.labelsize": 12,
        "legend.fontsize": 10,
        "figure.titlesize": 16})

    cube = cube.load()
    subset = cube.sel(along=slice(*region_range))
    times = pd.to_datetime([str(t)[:8] for t in subset.transect.values])
    selected_idxs = [i for i, t in enumerate(times)
                     if t.year in target_years and t.month in target_months]

    n_panels = len(selected_idxs)
    if n_panels == 0:
        # plt.subplots cannot build a grid with zero rows.
        print(f"No transects found for years {target_years} and months {target_months}.")
        return

    nrows = int(np.ceil(n_panels / ncols))
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                             figsize=(4.5 * ncols, 4 * nrows), squeeze=False)
    axes = axes.ravel()

    # Isopycnal grid
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 300), np.linspace(5.3, 20, 300))
    sigma = gsw.sigma0(S_grid, T_grid)

    # Pin every isopycnal label at theta = T_label: solve sigma0(S, T_label) = level for S.
    T_label = 6.0
    isopycnal_levels = [23.0, 23.5, 24.0, 24.5, 25.0, 25.5, 25.6, 25.7, 25.8, 25.9,
                        26.0, 26.1, 26.2, 26.3, 26.4, 26.5]
    manual_labels = {}
    for iso in isopycnal_levels:
        sol = root_scalar(lambda S, iso=iso: gsw.sigma0(S, T_label) - iso,
                          bracket=[28, 35], method='brentq')
        manual_labels[iso] = (sol.root if sol.converged else np.nan, T_label)

    isopycnal_styles = [
        (np.linspace(23, 27, 9), 'black', 0.5),
        ([25.6], 'white', 0.5),
        ([25.7], 'lime', 0.5),
        ([25.8], 'red', 0.5),
        ([25.9], 'blue', 0.5),
        ([26.0], 'black', 0.5),
        ([26.1], 'purple', 0.5),
        ([26.2], 'salmon', 0.5),
        ([26.3], 'yellow', 0.5),
        ([26.4], 'cyan', 0.5),
    ]

    cmap = plt.colormaps['jet']
    norm = Normalize(vmin=region_range[0], vmax=region_range[1])

    for i, idx in enumerate(selected_idxs):
        ax = axes[i]
        ax.set_facecolor('grey')

        # Plot isopycnals
        for levels, color, lw in isopycnal_styles:
            cs = ax.contour(S_grid, T_grid, sigma, levels=levels,
                            colors=color, linewidths=lw, linestyles='--')
            for level in levels:
                if level in manual_labels:
                    ax.clabel(cs, fmt='%1.2f', fontsize=7, inline=True,
                              manual=[manual_labels[level]])

        transect = subset.isel(transect=idx)
        temp = transect['temperature']
        sali = transect['salinity']

        # Scatter every along-track column in its own colour. The points are
        # rasterized so the saved PDF stays small; contour lines remain vector.
        for a_idx, along in enumerate(transect['along'].values):
            T_col = temp.isel(along=a_idx).values
            S_col = sali.isel(along=a_idx).values
            mask = ~np.isnan(T_col) & ~np.isnan(S_col)
            if np.any(mask):
                ax.scatter(S_col[mask], T_col[mask], s=10, color=cmap(norm(along)),
                           marker='o', rasterized=True)

        ax.set_title(times[idx].strftime('%d %b %Y'))

        # Tick labels only on the outer edge of the grid
        if i % ncols == 0:
            ax.set_ylabel("$\\theta$ (°C)")
        else:
            ax.tick_params(labelleft=False)
            ax.set_ylabel("")

        # A panel is on the bottom edge when no panel exists directly below it;
        # this also covers columns whose last-row slot is removed further down.
        if i + ncols >= n_panels:
            ax.set_xlabel("Salinity (psu)")
        else:
            ax.tick_params(labelbottom=False)
            ax.set_xlabel("")

        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

    # Remove unused axes
    for j in range(n_panels, len(axes)):
        fig.delaxes(axes[j])

    # Add colorbar outside plot
    fig.subplots_adjust(right=0.88, top=0.92)
    cbar_ax = fig.add_axes([0.9, 0.1, 0.02, 0.8])
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, cax=cbar_ax)
    cbar.set_label('Along-transect distance (m)', fontsize=12)

    fig.suptitle("T–S Diagrams Colored by Along Distance", fontsize=24)

    # Artists with zorder below 1 are rasterized as well when saving.
    for ax in fig.axes:
        ax.set_rasterization_zorder(1)
    if output_path is not None:
        fig.savefig(os.path.expanduser(output_path), format='pdf', dpi=600, bbox_inches='tight')

# Per-transect T–S grid for the first 30 km of the section, every month of 2024–2025.
plot_ts_shelf_basin_by_transect(
    cube,
    region_range=(0, 30000),
    target_years=[2024, 2025],
    target_months=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
    xlim=(30.5, 34),
    ylim=(5, 16),
    ncols=4,
)

In [None]:
def plot_ts_profile_averages(cube, basin_range=(5000, 15000), sill_range=(23500, 30000), shelf_range=(65000, 70000),
                             target_years=[2024], target_months=[5], xlim=None, ylim=None):
    """
    T–S diagram of mean profiles for the basin, sill and shelf regions.

    For each region, temperature and salinity are averaged over every matching
    transect and over the along-track range (dims 'transect' and 'along'); the
    remaining values are drawn as scatter points over dashed sigma-0 contours.

    Parameters
    ----------
    cube : xarray.Dataset
        Needs 'temperature' and 'salinity' with 'along' and 'transect'
        coordinates; the first 8 characters of each transect value are parsed
        as a date.
    basin_range, sill_range, shelf_range : (start, end)
        Along-track limits passed to ``cube.sel(along=slice(start, end))``.
        Legend entries report them in km (value / 1000), so the legend always
        matches the ranges actually used.
    target_years, target_months : sequence of int
        Transects outside these years/months are ignored (never modified).
    xlim, ylim : tuple or None
        Optional axis limits.

    Returns
    -------
    None. Creates a new figure and updates ``plt.rcParams`` globally.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import gsw
    from scipy.optimize import root_scalar

    plt.rcParams.update({
        "xtick.labelsize": 30,
        "ytick.labelsize": 30,
        "axes.titlesize": 30,
        "axes.labelsize": 30,
        "legend.fontsize": 20,
        "figure.titlesize": 40})

    cube = cube.load()
    fig, ax = plt.subplots(figsize=(1.5 * 1.5 * 6.4, 1.5 * 4.8))
    ax.set_facecolor('grey')

    # ─── Isopycnal Grid ───────────────────────────────────────
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 300), np.linspace(5.3, 20, 300))
    sigma = gsw.sigma0(S_grid, T_grid)

    # Manual isopycnal labels at fixed θ: solve sigma0(S, T_label) = level for S.
    T_label = 6.0
    isopycnal_levels = [23.0, 23.5, 24.0, 24.5, 25.0, 25.5, 25.6, 25.7, 25.8, 25.9, 26.0, 26.1, 26.2, 26.3, 26.4, 26.5]
    manual_labels = {}
    for iso in isopycnal_levels:
        sol = root_scalar(lambda S, iso=iso: gsw.sigma0(S, T_label) - iso,
                          bracket=[28, 35], method='brentq')
        manual_labels[iso] = (sol.root if sol.converged else np.nan, T_label)

    # Contour and label isopycnals
    for levels, color, lw in [
        (np.linspace(23, 27, 9), 'black', 0.5),
        ([25.6], 'white', 0.5),
        ([25.7], 'lime', 0.5),
        ([25.8], 'red', 0.5),
        ([25.9], 'blue', 0.5),
        ([26.0], 'black', 0.5),
        ([26.1], 'purple', 0.5),
        ([26.2], 'salmon', 0.5),
        ([26.3], 'yellow', 0.5),
        ([26.4], 'cyan', 0.5)
    ]:
        cs = ax.contour(S_grid, T_grid, sigma, levels=levels,
                        colors=color, linewidths=lw, linestyles='--')
        for level in levels:
            if level in manual_labels:
                ax.clabel(cs, fmt='%1.2f', fontsize=10, inline=False, inline_spacing=0.1,
                          manual=[manual_labels[level]])

    # ─── Mean Profiles by Region ────────────────────────────────
    # The transect coordinate is unaffected by slicing along 'along', so the
    # matching transects are identified once for all regions.
    times = pd.to_datetime([str(t)[:8] for t in cube.transect.values])
    selected_idxs = [i for i, t in enumerate(times)
                     if t.year in target_years and t.month in target_months]

    if selected_idxs:
        for (start, end), name, color in [
            (basin_range, 'Basin water', 'cyan'),
            (sill_range, 'Sill water', 'red'),
            (shelf_range, 'QCS water', 'salmon'),
        ]:
            # Build the legend text from the real limits instead of fixed strings.
            label = f"{name} ({start / 1000:g}–{end / 1000:g} km)"
            small = cube.sel(along=slice(start, end)).isel(transect=selected_idxs)[['temperature', 'salinity']]

            T = small['temperature'].mean(dim=('transect', 'along')).values
            S = small['salinity'].mean(dim=('transect', 'along')).values
            mask = ~np.isnan(T) & ~np.isnan(S)
            ax.scatter(S[mask], T[mask], label=label, color=color, s=20)

    ax.set_xlabel("Salinity (psu)")
    ax.set_ylabel(r"$\theta$ (°C)")
    ax.set_title("Mean T–S Profiles by Region")
    ax.legend()
    if xlim: ax.set_xlim(xlim)
    if ylim: ax.set_ylim(ylim)
    plt.tight_layout()

# Example: May 2024 mean T–S profiles for the basin, sill and shelf ranges.
plot_ts_profile_averages(
    cube,
    basin_range=(5000, 15000),
    sill_range=(23500, 30000),
    shelf_range=(65000, 75000),
    target_years=[2024],
    target_months=[5],
    xlim=(30.5, 34),
    ylim=(5, 12),
)

In [None]:
def plot_ts_profile_grid(cube, basin_range=(5000, 15000), sill_range=(23500, 30000), shelf_range=(65000, 70000),
                         target_years=[2024], target_months=list(range(1, 13)), xlim=None, ylim=None, ncols=3):
    """
    Grid of T–S diagrams with mean profiles for basin, sill, and shelf regions.

    One subplot per (year, month) that has at least one transect. Within a
    panel, each region's temperature and salinity are averaged over that
    month's transects and over the along-track range (dims 'transect' and
    'along'), then drawn as a line over dashed sigma-0 contours.

    NOTE: a later cell defines another function with this same name; once that
    cell has run, this definition is no longer reachable by name.

    Parameters
    ----------
    cube : xarray.Dataset
        Needs 'temperature' and 'salinity' with 'along' and 'transect'
        coordinates; the first 8 characters of each transect value are parsed
        as a date.
    basin_range, sill_range, shelf_range : (start, end)
        Along-track limits passed to ``cube.sel(along=slice(start, end))``.
        Legend entries report them in km (value / 1000).
    target_years, target_months : sequence of int
        Transects outside these years/months are ignored (never modified).
    xlim, ylim : tuple or None
        Optional axis limits applied to every panel.
    ncols : int
        Number of panel columns.

    Returns
    -------
    None. Prints a message and returns without creating a figure when no
    transect matches. Updates ``plt.rcParams`` globally.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import gsw
    from scipy.optimize import root_scalar

    plt.rcParams.update({
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "axes.titlesize": 14,
        "axes.labelsize": 14,
        "legend.fontsize": 10,
        "figure.titlesize": 16})

    # Group transect positions by (year, month) once; slicing along 'along'
    # does not change the transect coordinate, so the positions hold for every region.
    times = pd.to_datetime([str(t)[:8] for t in cube.transect.values])
    grouped = {}
    for i, t in enumerate(times):
        if t.year in target_years and t.month in target_months:
            grouped.setdefault((t.year, t.month), []).append(i)
    grouped_keys = sorted(grouped)

    n_panels = len(grouped_keys)
    if n_panels == 0:
        # plt.subplots cannot build a grid with zero rows.
        print(f"No transects found for years {target_years} and months {target_months}.")
        return

    nrows = int(np.ceil(n_panels / ncols))
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 5, nrows * 4),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.ravel()

    # ─── Isopycnal Grid and Labels ─────────────────────────────
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 300), np.linspace(5.3, 20, 300))
    sigma = gsw.sigma0(S_grid, T_grid)

    # Pin every isopycnal label at theta = T_label: solve sigma0(S, T_label) = level for S.
    T_label = 6.0
    isopycnal_levels = [23.0, 23.5, 24.0, 24.5, 25.0, 25.5, 25.6, 25.7, 25.8, 25.9, 26.0, 26.1, 26.2, 26.3, 26.4, 26.5]
    manual_labels = {}
    for iso in isopycnal_levels:
        sol = root_scalar(lambda S, iso=iso: gsw.sigma0(S, T_label) - iso,
                          bracket=[28, 35], method='brentq')
        manual_labels[iso] = (sol.root if sol.converged else np.nan, T_label)

    isopycnal_styles = [
        (np.linspace(23, 27, 9), 'black', 0.5),
        ([25.6], 'white', 0.5),
        ([25.7], 'lime', 0.5),
        ([25.8], 'red', 0.5),
        ([25.9], 'blue', 0.5),
        ([26.0], 'black', 0.5),
        ([26.1], 'purple', 0.5),
        ([26.2], 'salmon', 0.5),
        ([26.3], 'yellow', 0.5),
        ([26.4], 'cyan', 0.5),
    ]

    # Region subsets and legend text are the same for every panel; build them once.
    regions = [
        (cube.sel(along=slice(start, end))[['temperature', 'salinity']],
         f"{name} ({start / 1000:g}–{end / 1000:g} km)",
         color)
        for (start, end), name, color in [
            (basin_range, 'Basin water', 'cyan'),
            (sill_range, 'Sill water', 'red'),
            (shelf_range, 'QCS water', 'salmon'),
        ]
    ]

    for j, (year, month) in enumerate(grouped_keys):
        ax = axes[j]
        ax.set_facecolor('grey')

        # Plot isopycnals
        for levels, color, lw in isopycnal_styles:
            cs = ax.contour(S_grid, T_grid, sigma, levels=levels,
                            colors=color, linewidths=lw, linestyles='--')
            for level in levels:
                if level in manual_labels:
                    ax.clabel(cs, fmt='%1.2f', fontsize=8, inline=False, inline_spacing=0.1,
                              manual=[manual_labels[level]])

        has_data = False
        for region, label, color in regions:
            small = region.isel(transect=grouped[(year, month)])
            T = small['temperature'].mean(dim=('transect', 'along')).values
            S = small['salinity'].mean(dim=('transect', 'along')).values
            mask = ~np.isnan(T) & ~np.isnan(S)
            if np.any(mask):
                # Only regions with finite values count as data for this panel.
                ax.plot(S[mask], T[mask], label=label, color=color)
                has_data = True

        panel_title = pd.Timestamp(year=year, month=month, day=1).strftime('%b %Y')
        if has_data:
            ax.set_title(panel_title)
            ax.legend(fontsize=8)
        else:
            ax.set_title(f"{panel_title} (no data)")

        if j % ncols == 0:
            ax.set_ylabel(r"$\theta$ (°C)")
        # Bottom edge = no panel directly below. With sharex the tick labels of
        # such panels are hidden when they are not in the last row, so re-enable them.
        if j + ncols >= n_panels:
            ax.set_xlabel("Salinity (psu)")
            ax.tick_params(labelbottom=True)

    for j in range(n_panels, len(axes)):
        fig.delaxes(axes[j])

    if xlim:
        for ax in axes[:n_panels]: ax.set_xlim(xlim)
    if ylim:
        for ax in axes[:n_panels]: ax.set_ylim(ylim)

    fig.suptitle("Monthly Mean T–S Profiles by Region", fontsize=18)
    plt.tight_layout(rect=[0, 0, 1, 0.95])

# Monthly mean T–S profiles for 2024–2025, all twelve months, three panels per row.
plot_ts_profile_grid(
    cube,
    basin_range=(5000, 15000),
    sill_range=(23500, 30000),
    shelf_range=(65000, 75000),
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    xlim=(29, 34),
    ylim=(5, 16),
    ncols=3,
)

In [None]:
def plot_ts_profile_grid(cube, basin_range=(5000, 15000), sill_range=(23500, 30000), shelf_range=(65000, 70000),
                         target_years=[2024], target_months=list(range(1, 13)), xlim=None, ylim=None, ncols=3):
    """
    Grid of T–S diagrams with profiles for basin, sill, and shelf regions.

    One subplot per (year, month) that has at least one transect. Within a
    panel, every transect of that month is drawn separately: temperature and
    salinity are averaged over the region's along-track range (dim 'along')
    and scattered. The first transect of the month uses '.', later ones '^'.

    NOTE: this reuses the name of a function defined in an earlier cell and
    replaces it once this cell runs.

    Parameters
    ----------
    cube : xarray.Dataset
        Needs 'temperature' and 'salinity' with 'along' and 'transect'
        coordinates; the first 8 characters of each transect value are parsed
        as a date.
    basin_range, sill_range, shelf_range : (start, end)
        Along-track limits passed to ``cube.sel(along=slice(start, end))``.
        Legend entries report them in km (value / 1000).
    target_years, target_months : sequence of int
        Transects outside these years/months are ignored (never modified).
    xlim, ylim : tuple or None
        Optional axis limits applied to every panel.
    ncols : int
        Number of panel columns.

    Returns
    -------
    None. Prints a message and returns without creating a figure when no
    transect matches. Updates ``plt.rcParams`` globally.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import gsw
    from scipy.optimize import root_scalar

    plt.rcParams.update({
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "axes.titlesize": 14,
        "axes.labelsize": 14,
        "legend.fontsize": 10,
        "figure.titlesize": 16})

    # Transect positions grouped by (year, month)
    times = pd.to_datetime([str(t)[:8] for t in cube.transect.values])
    grouped = {}
    for i, t in enumerate(times):
        if t.year in target_years and t.month in target_months:
            grouped.setdefault((t.year, t.month), []).append(i)
    grouped_keys = sorted(grouped)

    n_panels = len(grouped_keys)
    if n_panels == 0:
        # plt.subplots cannot build a grid with zero rows.
        print(f"No transects found for years {target_years} and months {target_months}.")
        return

    nrows = int(np.ceil(n_panels / ncols))
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid.
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 5, nrows * 4),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.ravel()

    # ─── Isopycnal Grid and Labels ─────────────────────────────
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 300), np.linspace(5.3, 20, 300))
    sigma = gsw.sigma0(S_grid, T_grid)

    # Pin every isopycnal label at theta = T_label: solve sigma0(S, T_label) = level for S.
    T_label = 6.0
    isopycnal_levels = [23.0, 23.5, 24.0, 24.5, 25.0, 25.5, 25.6, 25.7, 25.8, 25.9, 26.0, 26.1, 26.2, 26.3, 26.4, 26.5]
    manual_labels = {}
    for iso in isopycnal_levels:
        sol = root_scalar(lambda S, iso=iso: gsw.sigma0(S, T_label) - iso,
                          bracket=[28, 35], method='brentq')
        manual_labels[iso] = (sol.root if sol.converged else np.nan, T_label)

    isopycnal_styles = [
        (np.linspace(23, 27, 9), 'black', 0.5),
        ([25.6], 'white', 0.5),
        ([25.7], 'lime', 0.5),
        ([25.8], 'red', 0.5),
        ([25.9], 'blue', 0.5),
        ([26.0], 'black', 0.5),
        ([26.1], 'purple', 0.5),
        ([26.2], 'salmon', 0.5),
        ([26.3], 'yellow', 0.5),
        ([26.4], 'cyan', 0.5),
    ]

    # Region subsets and legend text are the same for every panel; build them once.
    regions = [
        (cube.sel(along=slice(start, end)),
         f"{name} ({start / 1000:g}–{end / 1000:g} km)",
         color)
        for (start, end), name, color in [
            (basin_range, 'Basin water', 'cyan'),
            (sill_range, 'Sill water', 'red'),
            (shelf_range, 'QCS water', 'salmon'),
        ]
    ]

    for j, (year, month) in enumerate(grouped_keys):
        ax = axes[j]
        ax.set_facecolor('grey')

        # Plot isopycnals
        for levels, color, lw in isopycnal_styles:
            cs = ax.contour(S_grid, T_grid, sigma, levels=levels,
                            colors=color, linewidths=lw, linestyles='--')
            for level in levels:
                if level in manual_labels:
                    ax.clabel(cs, fmt='%1.2f', fontsize=8, inline=False, inline_spacing=0.1,
                              manual=[manual_labels[level]])

        month_transects = grouped[(year, month)]
        first_transect = month_transects[0]
        has_data = False
        for region, region_label, color in regions:
            # Attach the legend label to the first transect that is actually drawn,
            # so a region keeps its legend entry even if its first transect is all-NaN.
            pending_label = region_label
            for i in month_transects:
                T = region['temperature'].isel(transect=i).mean(dim='along').values
                S = region['salinity'].isel(transect=i).mean(dim='along').values
                mask = ~np.isnan(T) & ~np.isnan(S)
                if np.any(mask):
                    marker_style = '.' if i == first_transect else '^'
                    ax.scatter(S[mask], T[mask], label=pending_label, color=color,
                               s=5, marker=marker_style)
                    pending_label = None
                    has_data = True

        panel_title = pd.Timestamp(year=year, month=month, day=1).strftime('%b %Y')
        if has_data:
            ax.set_title(panel_title)
            ax.legend(fontsize=8)
        else:
            ax.set_title(f"{panel_title} (no data)")

        if j % ncols == 0:
            ax.set_ylabel(r"$\theta$ (°C)")
        # Bottom edge = no panel directly below. With sharex the tick labels of
        # such panels are hidden when they are not in the last row, so re-enable them.
        if j + ncols >= n_panels:
            ax.set_xlabel("Salinity (psu)")
            ax.tick_params(labelbottom=True)

    for j in range(n_panels, len(axes)):
        fig.delaxes(axes[j])

    if xlim:
        for ax in axes[:n_panels]: ax.set_xlim(xlim)
    if ylim:
        for ax in axes[:n_panels]: ax.set_ylim(ylim)

    fig.suptitle("Monthly T–S Profiles by Region (per transect)", fontsize=18)
    plt.tight_layout(rect=[0, 0, 1, 0.95])

# Per-transect T–S profiles for 2024–2025, all twelve months, three panels per row.
plot_ts_profile_grid(
    cube,
    basin_range=(5000, 15000),
    sill_range=(23500, 30000),
    shelf_range=(65000, 75000),
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    xlim=(30.5, 34),
    ylim=(5, 16),
    ncols=3,
)

In [None]:
def plot_ts_profile_by_transect(
    cube,
    basin_range=(5000, 15000),
    sill_range=(23500, 30000),
    target_years=[2024],
    target_months=list(range(1, 13)),
    xlim=None,
    ylim=None,
    ncols=3,
    inset_months=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],  # Toggle insets by month
    inset_xlim=(33, 33.5),  # Zoom box salinity range
    inset_ylim=(6.5, 7.5),   # Zoom box temperature range
    inset_limits_by_date=None,
    target_depth=130
):
    """
    Grid of T–S diagrams, one panel per transect, for the basin and sill regions.

    For every transect whose date falls in ``target_years``/``target_months``,
    temperature and salinity are averaged along-track over ``basin_range`` and
    ``sill_range`` and scattered (one point per depth level) over dashed
    sigma-theta contours. The point nearest ``target_depth`` is highlighted and
    annotated. Panels whose month is in ``inset_months`` also get a zoomed
    inset, and a dashed rectangle on the main axes shows the zoomed area.

    Parameters
    ----------
    cube : dataset with ``temperature`` and ``salinity`` variables and
        ``transect``, ``depth`` and ``along`` coordinates (presumably an
        xarray.Dataset — confirm against the loader). The first 8 characters
        of each transect label are parsed as the transect date.
    basin_range, sill_range : (start, end) bounds passed to
        ``cube.sel(along=slice(start, end))``. Legend labels are derived from
        these bounds divided by 1000 and shown as km.
    target_years, target_months : containers of ints used to select transects.
        The list defaults are never mutated.
    xlim, ylim : optional axis limits applied to every panel.
    ncols : number of columns in the panel grid.
    inset_months : months for which the zoom inset is drawn.
    inset_xlim, inset_ylim : default salinity / temperature window of the inset.
    inset_limits_by_date : optional dict mapping ``'YYYY-MM-DD'`` to a dict
        with ``'xlim'`` and/or ``'ylim'`` overriding the inset window for that
        transect.
    target_depth : depth value whose nearest level is highlighted in each
        profile (same units as the ``depth`` coordinate).

    Returns
    -------
    None. Prints a message and returns without creating a figure when no
    transect matches the requested years/months.

    Notes
    -----
    Updates ``plt.rcParams`` globally when a figure is drawn.
    """
    import numpy as np
    import pandas as pd

    # ─── Transect Times ─────────────────────────────
    all_times = pd.to_datetime([str(t)[:8] for t in cube.transect.values])
    valid_indices = [i for i, t in enumerate(all_times)
                     if t.year in target_years and t.month in target_months]
    if not valid_indices:
        # plt.subplots(0, ncols) would raise; bail out before touching any global state.
        print(f"No transects found for years {target_years} and months {target_months}.")
        return
    times = all_times[valid_indices]

    # The plotting stack is only needed once there is something to draw.
    import matplotlib.pyplot as plt
    import gsw
    from matplotlib.patches import Rectangle
    from scipy.optimize import root_scalar

    plt.rcParams.update({
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "axes.titlesize": 14,
        "axes.labelsize": 14,
        "legend.fontsize": 10,
        "figure.titlesize": 16})

    n_panels = len(valid_indices)
    nrows = int(np.ceil(n_panels / ncols))
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so flatten() always works.
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 5, nrows * 4),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.flatten()

    # ─── Isopycnal Grid and Labels ─────────────────────────────
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 300), np.linspace(5.3, 20, 300))
    sigma = gsw.sigma0(S_grid, T_grid)
    T_label = 6.0  # temperature at which every contour label is anchored
    isopycnal_levels = [23.0, 23.5, 24.0, 24.5, 25.0, 25.5, 25.6, 25.7, 25.8, 25.9,
                        26.0, 26.1, 26.2, 26.3, 26.4, 26.5]
    manual_labels = {}
    for iso in isopycnal_levels:
        # Salinity at which this isopycnal crosses T_label; used as the label position.
        sol = root_scalar(lambda S, iso=iso: gsw.sigma0(S, T_label) - iso,
                          bracket=[28, 35], method='brentq')
        manual_labels[iso] = (sol.root if sol.converged else np.nan, T_label)

    # One definition shared by the main axes and the insets.
    isopycnal_specs = [
        (np.linspace(23, 27, 9), 'black', 0.5),
        ([25.6], 'white', 0.5),
        ([25.7], 'lime', 0.5),
        ([25.8], 'red', 0.5),
        ([25.9], 'blue', 0.5),
        ([26.0], 'black', 0.5),
        ([26.1], 'purple', 0.5),
        ([26.2], 'salmon', 0.5),
        ([26.3], 'yellow', 0.5),
        ([26.4], 'cyan', 0.5),
    ]

    def _draw_isopycnals(target_ax, label_fontsize):
        """Draw the dashed sigma-theta contours and their anchored labels on target_ax."""
        for levels, color, lw in isopycnal_specs:
            cs = target_ax.contour(S_grid, T_grid, sigma, levels=levels,
                                   colors=color, linewidths=lw, linestyles='--')
            for level in levels:
                if level in manual_labels:
                    target_ax.clabel(cs, fmt='%1.2f', fontsize=label_fontsize, inline=False,
                                     inline_spacing=0.1, manual=[manual_labels[level]])

    def _mark_depth(target_ax, avg_temp, avg_sali, color, size, bounds=None):
        """Highlight and annotate the profile point nearest target_depth.

        When `bounds` = (xlim, ylim) is given, the marker is skipped if it falls
        outside that window.
        """
        try:
            t_mark = avg_temp.sel(depth=target_depth, method='nearest').values
            s_mark = avg_sali.sel(depth=target_depth, method='nearest').values
        except KeyError:
            return  # depth not found
        if np.isnan(t_mark) or np.isnan(s_mark):
            return
        if bounds is not None:
            (x_lo, x_hi), (y_lo, y_hi) = bounds
            if not (x_lo <= s_mark <= x_hi and y_lo <= t_mark <= y_hi):
                return
        target_ax.scatter(s_mark, t_mark, color='black', s=size)
        target_ax.text(s_mark, t_mark, f'{int(target_depth)} m',
                       color=color, ha='left', va='bottom')

    def _km_label(name, bounds):
        """Legend text built from the actual along-track bounds (shown in km)."""
        return f"{name} ({bounds[0] / 1000:g}–{bounds[1] / 1000:g} km)"

    regions = [
        (basin_range, _km_label('Basin', basin_range), 'cyan'),
        (sill_range, _km_label('Sill', sill_range), 'red'),
    ]

    for j, i in enumerate(valid_indices):
        ax = axes[j]
        ax.set_facecolor('grey')
        _draw_isopycnals(ax, label_fontsize=8)

        # Along-track mean profiles, computed once and reused by the inset.
        profiles = []
        for (start, end), label, color in regions:
            region = cube.sel(along=slice(start, end))
            avg_temp = region['temperature'].isel(transect=i).mean(dim='along')
            avg_sali = region['salinity'].isel(transect=i).mean(dim='along')
            T = avg_temp.values
            S = avg_sali.values
            mask = ~np.isnan(T) & ~np.isnan(S)
            if np.any(mask):
                profiles.append((label, color, avg_temp, avg_sali, S[mask], T[mask]))

        if not profiles:
            continue  # panel stays untitled, showing only the contours

        for label, color, avg_temp, avg_sali, S, T in profiles:
            ax.scatter(S, T, label=label, color=color, s=5)
            _mark_depth(ax, avg_temp, avg_sali, color, size=6)

        ax.set_title(times[j].strftime('%d %b %Y'))
        if j % ncols == 0:
            ax.set_ylabel(r"$\theta$ (°C)")
        if j // ncols == nrows - 1:
            ax.set_xlabel("Salinity (psu)")

        # ─── Add Inset Only for Selected Months ─────────────────────────────
        if times[j].month not in inset_months:
            continue

        inset = ax.inset_axes([0.5, 0.5, 0.48, 0.48])
        inset.set_facecolor('grey')

        # Per-date overrides fall back to the function-level inset window.
        overrides = (inset_limits_by_date or {}).get(times[j].strftime('%Y-%m-%d'), {})
        inset_xlim_local = overrides.get('xlim', inset_xlim)
        inset_ylim_local = overrides.get('ylim', inset_ylim)

        _draw_isopycnals(inset, label_fontsize=6)
        for label, color, avg_temp, avg_sali, S, T in profiles:
            inset.scatter(S, T, color=color, s=3)
            # Only annotate inside the zoom window.
            _mark_depth(inset, avg_temp, avg_sali, color, size=4,
                        bounds=(inset_xlim_local, inset_ylim_local))

        inset.set_xlim(*inset_xlim_local)
        inset.set_ylim(*inset_ylim_local)
        inset.tick_params(labelsize=6)
        inset.set_xticks([])
        inset.set_yticks([])

        # ─── Draw dashed black box in main axes ─────────────────────────────
        ax.add_patch(Rectangle(
            (inset_xlim_local[0], inset_ylim_local[0]),
            inset_xlim_local[1] - inset_xlim_local[0],
            inset_ylim_local[1] - inset_ylim_local[0],
            edgecolor='black',
            facecolor='none',
            linestyle='--',
            linewidth=1
        ))

    # ─── Master Legend Items ─────────────────────────────
    legend_handles = [plt.Line2D([0], [0], marker='o', color='w', label=label,
                                 markerfacecolor=color, markersize=15)
                      for _, label, color in regions]

    for j in range(n_panels, len(axes)):
        fig.delaxes(axes[j])

    if xlim:
        for ax in axes[:n_panels]:
            ax.set_xlim(xlim)
    if ylim:
        for ax in axes[:n_panels]:
            ax.set_ylim(ylim)

    fig.legend(handles=legend_handles, loc='upper right', ncol=1, fontsize=20)
    fig.suptitle("Depth-Averaged T–S Profiles", fontsize=32)
    plt.tight_layout(rect=[0, 0, 1, 0.97])  # Leave room at top for legend

# Per-transect T–S panels for 2024–2025; zoom insets are drawn for April–September only,
# with hand-tuned zoom windows for selected transect dates.
plot_ts_profile_by_transect(
    cube,
    basin_range=(0, 15000),
    sill_range=(20000, 30000),
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    xlim=(30.5, 34),
    ylim=(5, 16),
    ncols=4,
    inset_months=list(range(4, 10)),
    inset_limits_by_date={
        '2024-04-16': dict(xlim=(32.4, 33.5), ylim=(7.3, 8.5)),
        '2024-05-10': dict(xlim=(32.7, 33.5), ylim=(7.2, 8.1)),
        '2024-05-16': dict(xlim=(32.9, 33.6), ylim=(6.8, 7.9)),
        '2024-07-17': dict(xlim=(33.3, 33.8), ylim=(6, 7.1)),
        '2024-07-23': dict(xlim=(33.3, 33.8), ylim=(6, 7.1)),
    },
)

In [None]:
def plot_o2_profile_by_transect(
    cube,
    basin_range=(5000, 15000),
    sill_range=(23500, 30000),
    shelf_range=(65000, 70000),
    target_years=[2024],
    target_months=list(range(1, 13)),
    xlim=None,
    ylim=None,
    ncols=3):
    """
    Grid of oxygen–temperature diagrams, one panel per transect.

    For every transect whose date falls in ``target_years``/``target_months``,
    temperature and oxygen concentration are averaged along-track over the
    basin, sill and shelf ranges and scattered (one point per depth level).
    Transects with no depth level that has both a valid temperature and a
    valid oxygen value in any region are skipped.

    No isopycnals are drawn: the earlier version computed an S–T density grid
    but never used it, and it cannot be placed on O2–T axes.

    Parameters
    ----------
    cube : dataset with ``temperature`` and ``oxygen_concentration`` variables
        and ``transect`` and ``along`` coordinates (presumably an
        xarray.Dataset — confirm against the loader). The first 8 characters
        of each transect label are parsed as the transect date.
    basin_range, sill_range, shelf_range : (start, end) bounds passed to
        ``cube.sel(along=slice(start, end))``. Legend labels are derived from
        these bounds divided by 1000 and shown as km, so they always describe
        the range that was actually averaged.
    target_years, target_months : containers of ints used to select transects.
        The list defaults are never mutated.
    xlim, ylim : optional axis limits applied to every panel.
    ncols : number of columns in the panel grid.

    Returns
    -------
    None. Prints a message and returns without creating a figure when no
    transect has valid data.

    Notes
    -----
    Updates ``plt.rcParams`` globally when a figure is drawn.
    """
    import numpy as np
    import pandas as pd

    # ─── Transect Times ─────────────────────────────
    times = pd.to_datetime([str(t)[:8] for t in cube.transect.values])
    valid_indices = [i for i, t in enumerate(times)
                     if t.year in target_years and t.month in target_months]

    def _km_label(name, bounds):
        """Legend text built from the actual along-track bounds (shown in km)."""
        return f"{name} ({bounds[0] / 1000:g}–{bounds[1] / 1000:g} km)"

    regions = [
        (basin_range, _km_label('Basin', basin_range), 'cyan'),
        (sill_range, _km_label('Sill', sill_range), 'red'),
        (shelf_range, _km_label('QCS', shelf_range), 'salmon'),
    ]

    # ─── Collect profiles; keep only transects with valid O2–T data ─────────────
    # Each panel is (time, [(label, color, O2, T), ...]); the means are computed once here
    # and reused for plotting.
    panels = []
    for i in valid_indices:
        series = []
        for (start, end), label, color in regions:
            region = cube.sel(along=slice(start, end))
            T = region['temperature'].isel(transect=i).mean(dim='along').values
            O2 = region['oxygen_concentration'].isel(transect=i).mean(dim='along').values
            mask = ~np.isnan(T) & ~np.isnan(O2)
            if np.any(mask):
                series.append((label, color, O2[mask], T[mask]))
        if series:
            panels.append((times[i], series))

    n_panels = len(panels)
    if n_panels == 0:
        print("No transects with valid O2–temperature data.")
        return

    # The plotting stack is only needed once there is something to draw.
    import matplotlib.pyplot as plt

    plt.rcParams.update({
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "axes.titlesize": 14,
        "axes.labelsize": 14,
        "legend.fontsize": 10,
        "figure.titlesize": 16})

    nrows = int(np.ceil(n_panels / ncols))
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so flatten() always works.
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 5, nrows * 4),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.flatten()

    for j, (t, series) in enumerate(panels):
        ax = axes[j]
        ax.set_facecolor('grey')

        for label, color, O2, T in series:
            ax.scatter(O2, T, label=label, color=color, s=5)

        ax.set_title(t.strftime('%d %b %Y'))
        ax.legend(fontsize=8)

        if j % ncols == 0:
            ax.set_ylabel(r"$\theta$ (°C)")
        if j // ncols == nrows - 1:
            ax.set_xlabel("Oxygen (μmol/kg)")

    for j in range(n_panels, len(axes)):
        fig.delaxes(axes[j])

    if xlim:
        for ax in axes[:n_panels]:
            ax.set_xlim(xlim)
    if ylim:
        for ax in axes[:n_panels]:
            ax.set_ylim(ylim)

    fig.suptitle("Oxygen–Temperature Profiles by Transect", fontsize=18)
    plt.tight_layout(rect=[0, 0, 1, 0.95])


# Two figures with identical regions and axis limits: one for 2023, one for 2024–2025.
for o2_years in ([2023], [2024, 2025]):
    plot_o2_profile_by_transect(
        cube,
        basin_range=(0, 15000),
        sill_range=(23500, 30000),
        shelf_range=(65000, 75000),
        target_years=o2_years,
        target_months=list(range(1, 13)),
        xlim=(0, 350),  # or None for auto
        ylim=(4, 14),
        ncols=4
    )

In [None]:
def plot_ts(cube, along_values, target_years=[2023, 2024],
            shallowest_depth=None, xlim=None, ylim=None, target_months=None):
    """
    Scatter-style T–S diagram with dashed sigma-theta contours.

    Every selected transect is drawn in its own colour (sampled from the 'jet'
    colormap in transect order) and every along-track location gets its own
    marker shape, so locations can be told apart within one date.

    Parameters
    ----------
    cube : dataset with ``temperature`` and ``salinity`` variables and
        ``transect`` and ``along`` coordinates (presumably an xarray.Dataset).
        NOTE(review): the positional indexing ``[transect, :, along]`` assumes
        the variables are ordered (transect, depth, along) — confirm.
    along_values : scalar or sequence of along-track coordinate values. Only
        exact matches survive the ``isin`` filter.
    target_years : container of years to include.
    shallowest_depth : accepted for backward compatibility; currently unused
        (all depth levels are plotted).
    xlim, ylim : optional axis limits; when omitted they are set to the data
        range padded by 0.1.
    target_months : optional container of months to include; ``None`` keeps
        all months.

    Returns
    -------
    None. Prints a message and returns without creating a figure when no
    transect matches, or when none of ``along_values`` exists in the cube.

    Notes
    -----
    Updates ``plt.rcParams`` globally when a figure is drawn. Earlier versions
    assigned an empty marker to the second location (drawing nothing for it)
    and raised IndexError for a single location; markers are now cycled.
    """
    import numpy as np
    import pandas as pd

    # Accept Python and numpy scalars alike.
    if np.ndim(along_values) == 0:
        along_values = [along_values]

    cube = cube.where(cube['along'].isin(along_values), drop=True)

    times = [pd.to_datetime(str(t)[:8]) for t in cube.transect.values]
    selected_idxs = [i for i, t in enumerate(times)
                     if t.year in target_years and (target_months is None or t.month in target_months)]

    if not selected_idxs:
        print(f"No valid transects for years {target_years} at along = {along_values}")
        return

    along_coords = cube['temperature'].coords['along'].values
    if along_coords.size == 0:
        # argmin over an empty axis would raise an opaque ValueError further down.
        print(f"None of along = {along_values} is present in the cube")
        return
    along_indices = [int(np.argmin(np.abs(along_coords - a))) for a in along_values]

    # The plotting stack is only needed once there is something to draw.
    import matplotlib.pyplot as plt
    import gsw

    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 20,
        'figure.titlesize': 20})

    fig, ax = plt.subplots(figsize=(1.5 * 1.5 * 6.4, 1.5 * 4.8))
    ax.set_facecolor('grey')

    # Colormap for time: one colour per selected transect.
    cmap = plt.colormaps.get_cmap('jet')
    color_map = cmap(np.linspace(0, 1, len(selected_idxs)))

    # One visible shape per along_value, cycling if there are more locations than shapes.
    marker_cycle = ['o', '^', 's', 'D', 'v', 'P', 'X', '*']
    marker_map = {a: marker_cycle[k % len(marker_cycle)] for k, a in enumerate(along_values)}

    # Sigma-theta grid
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 200),
                                 np.linspace(5.3, 20, 200))
    sigma = gsw.sigma0(S_grid, T_grid)

    # Contours are intentionally unlabelled (the old label branch could never run).
    for levels, color in [
            (np.linspace(24, 27, 7), 'black'),
            ([25.6], 'white'),
            ([25.7], 'lime'),
            ([25.8], 'red'),
            ([25.9], 'blue'),
            ([26.0], 'black'),
            ([26.1], 'purple'),
            ([26.2], 'salmon'),
            ([26.3], 'yellow'),
            ([26.4], 'cyan')]:
        ax.contour(S_grid, T_grid, sigma, levels=levels,
                   colors=color, linewidths=0.5, linestyles='--')

    all_S, all_T = [], []
    for k, idx in enumerate(selected_idxs):
        date = times[idx]

        for along_val, along_idx in zip(along_values, along_indices):
            # All depth levels of this transect at this location.
            T = cube['temperature'][idx, :, along_idx].values.flatten()
            S = cube['salinity'][idx, :, along_idx].values.flatten()

            mask = ~np.isnan(T) & ~np.isnan(S)
            if not np.any(mask):
                continue

            ax.scatter(
                S[mask], T[mask],
                color=color_map[k],
                label=f"{date.strftime('%d %b %Y')} @ {int(along_val)} m",
                marker=marker_map.get(along_val, 'o'),
                linewidth=0.2)
            all_S.extend(S[mask])
            all_T.extend(T[mask])

    # Axis limits: default to the plotted data range with a small margin.
    if all_S:
        if xlim is None:
            xlim = (np.nanmin(all_S) - 0.1, np.nanmax(all_S) + 0.1)
        if ylim is None:
            ylim = (np.nanmin(all_T) - 0.1, np.nanmax(all_T) + 0.1)

    ax.set_xlabel("Salinity (psu)")
    ax.set_ylabel(r"$\theta$ (°C)")
    ax.set_title(f"T–S Diagram at {', '.join([str(a)+' m' for a in along_values])}")
    ax.legend(fontsize=9, loc='upper right')
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    plt.tight_layout()

# Single-axes T–S diagram for March–August 2024 at two along-track locations.
plot_ts(
    cube,
    along_values=[6500, 23500],   # in meters
    target_years=[2024],
    target_months=list(range(3, 9)),
    shallowest_depth=0,
)



In [None]:
def plot_ts_grid(cube, along_values, target_years=[2023, 2024],
                     shallowest_depth=None, xlim=None, ylim=None, target_months=None,
                     ncols=3):
    """
    T–S diagrams with one subplot per (year, month).

    All transects of a month share a panel. Every profile uses the same marker
    shape; colour is looked up per along-track location (6500 → red,
    23500 → cyan, anything else gray). Shared x/y labels are drawn once for
    the whole figure.

    Parameters
    ----------
    cube : dataset with ``temperature`` and ``salinity`` variables and
        ``transect`` and ``along`` coordinates (presumably an xarray.Dataset).
        NOTE(review): the positional indexing ``[transect, :, along]`` assumes
        the variables are ordered (transect, depth, along) — confirm.
    along_values : scalar or sequence of along-track coordinate values. Only
        exact matches survive the ``isin`` filter.
    target_years : container of years to include.
    shallowest_depth : accepted for backward compatibility; currently unused.
    xlim, ylim : optional axis limits; when omitted they default to the
        plotted data range padded by 0.1.
    target_months : optional container of months; ``None`` keeps all months.
    ncols : number of columns in the panel grid.

    Returns
    -------
    None. Prints a message and returns without creating a figure when no
    transect matches, or when none of ``along_values`` exists in the cube.

    Notes
    -----
    Updates ``plt.rcParams`` globally when a figure is drawn. The figure title
    is fixed to "T–S Diagrams at Sill and FHS Deep Basin".
    """
    import numpy as np
    import pandas as pd

    # Accept Python and numpy scalars alike.
    if np.ndim(along_values) == 0:
        along_values = [along_values]

    # ─── Filter by station ─────────────────────────────────────
    cube = cube.where(cube['along'].isin(along_values), drop=True)

    # ─── Select time points ────────────────────────────────────
    times = [pd.to_datetime(str(t)[:8]) for t in cube.transect.values]
    selected = [(i, t) for i, t in enumerate(times)
                if t.year in target_years and (target_months is None or t.month in target_months)]

    if not selected:
        print(f"No valid transects for years {target_years} at along = {along_values}")
        return

    along_coords = cube['temperature'].coords['along'].values
    if along_coords.size == 0:
        # argmin over an empty axis would raise an opaque ValueError further down.
        print(f"None of along = {along_values} is present in the cube")
        return
    along_indices = [int(np.argmin(np.abs(along_coords - a))) for a in along_values]

    # The plotting stack is only needed once there is something to draw.
    import matplotlib.pyplot as plt
    import gsw

    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 16,
        'axes.labelsize': 14,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 10,
        'figure.titlesize': 20})

    # ─── Group by (year, month) ────────────────────────────────
    grouped = {}
    for i, t in selected:
        grouped.setdefault((t.year, t.month), []).append((i, t))

    sorted_keys = sorted(grouped)
    n_panels = len(sorted_keys)
    nrows = int(np.ceil(n_panels / ncols))
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so flatten() always works.
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 5, nrows * 4),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.flatten()

    # ─── Custom colors ─────────────────────────────────────────
    station_colors = {
        6500: 'red',
        23500: 'cyan'
    }
    marker_style = '.'

    # ─── Sigma-theta grid ──────────────────────────────────────
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 200),
                                 np.linspace(5, 20, 200))
    sigma = gsw.sigma0(S_grid, T_grid)

    # Contours are intentionally unlabelled (the old label branch could never run).
    isopycnal_specs = [
        (np.linspace(24, 27, 7), 'black'),
        ([25.6], 'white'), ([25.7], 'lime'), ([25.8], 'red'),
        ([25.9], 'blue'), ([26.0], 'black'), ([26.1], 'purple'),
        ([26.2], 'salmon'), ([26.3], 'yellow'), ([26.4], 'cyan'),
    ]

    all_S, all_T = [], []

    for j, (year, month) in enumerate(sorted_keys):
        ax = axes[j]
        ax.set_facecolor('lightgrey')

        # ─── Density Contours ───────────────────────────────────
        for levels, contour_color in isopycnal_specs:
            ax.contour(S_grid, T_grid, sigma, levels=levels,
                       colors=contour_color, linewidths=0.5, linestyles='--')

        n_series = 0
        for i, t in grouped[(year, month)]:
            for along_val, along_idx in zip(along_values, along_indices):
                T = cube['temperature'][i, :, along_idx].values.flatten()
                S = cube['salinity'][i, :, along_idx].values.flatten()
                mask = ~np.isnan(T) & ~np.isnan(S)
                if not np.any(mask):
                    continue

                ax.scatter(S[mask], T[mask],
                           s=60,
                           color=station_colors.get(along_val, 'gray'),
                           marker=marker_style,
                           linewidth=0.2,
                           label=f"{t.strftime('%d %b')} @ {int(along_val)} m")
                all_S.extend(S[mask])
                all_T.extend(T[mask])
                n_series += 1

        ax.set_title(pd.Timestamp(year=year, month=month, day=1).strftime('%B %Y'))
        if n_series:
            # An empty panel has nothing to list; calling legend() there only emits a warning.
            ax.legend(fontsize=8, loc='upper right')

    # ─── Axis Limits ───────────────────────────────────────────
    if all_S:
        if xlim is None:
            xlim = (np.nanmin(all_S) - 0.1, np.nanmax(all_S) + 0.1)
        if ylim is None:
            ylim = (np.nanmin(all_T) - 0.1, np.nanmax(all_T) + 0.1)
    # Caller-supplied limits are honoured even when nothing was plotted.
    for ax in axes[:n_panels]:
        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None:
            ax.set_ylim(ylim)

    # ─── Shared Labels ─────────────────────────────────────────
    fig.text(0.5, 0.01, "Salinity (psu)", ha='center', fontsize=16)
    fig.text(0.01, 0.5, r"$\theta$ (°C)", va='center', rotation='vertical', fontsize=16)

    # ─── Remove unused axes & title ────────────────────────────
    for j in range(n_panels, len(axes)):
        fig.delaxes(axes[j])

    fig.suptitle("T–S Diagrams at Sill and FHS Deep Basin")
    plt.tight_layout(rect=[0.05, 0.05, 1, 0.95])

# Monthly T–S panels for 2024–2025 (all months) at two along-track locations.
plot_ts_grid(
    cube,
    along_values=[6500, 23500],
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    xlim=(30, 34),
    ylim=(5, 11),
)

In [None]:

def plot_to_grid(cube, along_values, target_years=[2023, 2024],
                 shallowest_depth=None, xlim=None, ylim=None,
                 target_months=None, ncols=3):
    """
    Oxygen–temperature diagrams with one subplot per (year, month).

    All transects of a month share a panel. Colour is looked up per
    along-track location (6500 → red, 23500 → cyan, anything else gray).
    Months without a single depth level that has both a valid oxygen and a
    valid temperature value get no panel. Shared x/y labels are drawn once for
    the whole figure.

    Parameters
    ----------
    cube : dataset with ``temperature`` and ``oxygen_concentration`` variables
        and ``transect`` and ``along`` coordinates (presumably an
        xarray.Dataset). NOTE(review): the positional indexing
        ``[transect, :, along]`` assumes the variables are ordered
        (transect, depth, along) — confirm.
    along_values : scalar or sequence of along-track coordinate values. Only
        exact matches survive the ``isin`` filter.
    target_years : container of years to include.
    shallowest_depth : accepted for backward compatibility; currently unused.
    xlim, ylim : optional axis limits applied to every panel; when omitted
        matplotlib autoscaling is kept.
    target_months : optional container of months; ``None`` keeps all months.
    ncols : number of columns in the panel grid.

    Returns
    -------
    None. Prints a message and returns without creating a figure when no
    transect matches, none of ``along_values`` exists in the cube, or no month
    has valid data.

    Notes
    -----
    Updates ``plt.rcParams`` globally when a figure is drawn. The figure title
    is fixed to "Oxygen–Temperature Diagrams at Sill and FHS Deep Basin".
    """
    import numpy as np
    import pandas as pd

    # Accept Python and numpy scalars alike.
    if np.ndim(along_values) == 0:
        along_values = [along_values]

    cube = cube.where(cube['along'].isin(along_values), drop=True)

    times = [pd.to_datetime(str(t)[:8]) for t in cube.transect.values]
    selected = [(i, t) for i, t in enumerate(times)
                if t.year in target_years and (target_months is None or t.month in target_months)]

    if not selected:
        print(f"No valid transects for years {target_years} at along = {along_values}")
        return

    along_coords = cube['along'].values
    if along_coords.size == 0:
        # argmin over an empty axis would raise an opaque ValueError further down.
        print(f"None of along = {along_values} is present in the cube")
        return
    along_indices = [int(np.argmin(np.abs(along_coords - a))) for a in along_values]

    grouped = {}
    for i, t in selected:
        grouped.setdefault((t.year, t.month), []).append((i, t))

    # Extract every profile once; months that yield no valid profile are dropped.
    # Each panel is ((year, month), [(time, along_val, O2, T), ...]).
    panels = []
    for key in sorted(grouped):
        series = []
        for i, t in grouped[key]:
            for along_val, along_idx in zip(along_values, along_indices):
                O2 = cube['oxygen_concentration'][i, :, along_idx].values.flatten()
                T = cube['temperature'][i, :, along_idx].values.flatten()
                mask = ~np.isnan(O2) & ~np.isnan(T)
                if np.any(mask):
                    series.append((t, along_val, O2[mask], T[mask]))
        if series:
            panels.append((key, series))

    if not panels:
        print("No valid months with data.")
        return

    # The plotting stack is only needed once there is something to draw.
    import matplotlib.pyplot as plt

    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 16,
        'axes.labelsize': 14,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 10,
        'figure.titlesize': 20})

    n_panels = len(panels)
    nrows = int(np.ceil(n_panels / ncols))
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so flatten() always works.
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 5, nrows * 4),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.flatten()

    station_colors = {
        6500: 'red',
        23500: 'cyan'
    }
    marker_style = '.'

    for j, ((year, month), series) in enumerate(panels):
        ax = axes[j]
        ax.set_facecolor('lightgrey')

        for t, along_val, O2, T in series:
            ax.scatter(O2, T,
                       s=60,
                       color=station_colors.get(along_val, 'gray'),
                       marker=marker_style,
                       linewidth=0.2,
                       label=f"{t.strftime('%d %b')} @ {int(along_val)} m")

        ax.set_title(pd.Timestamp(year=year, month=month, day=1).strftime('%B %Y'))
        ax.legend(fontsize=8, loc='upper right')

    # Every remaining panel has data, so explicit limits can be applied unconditionally.
    for ax in axes[:n_panels]:
        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None:
            ax.set_ylim(ylim)

    fig.text(0.5, 0.01, "Oxygen (μmol/kg)", ha='center', fontsize=16)
    fig.text(0.01, 0.5, r"$\theta$ (°C)", va='center', rotation='vertical', fontsize=16)

    for j in range(n_panels, len(axes)):
        fig.delaxes(axes[j])

    fig.suptitle("Oxygen–Temperature Diagrams at Sill and FHS Deep Basin")
    plt.tight_layout(rect=[0.05, 0.05, 1, 0.95])

# Monthly O2–T panels for 2024–2025 (all months); the oxygen axis is left on autoscale.
plot_to_grid(
    cube,
    along_values=[6500, 23500],
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    # xlim=(30, 34),
    ylim=(5, 11),
)