# Herein lies the code to plot every figure in my Summer 2025 SURA report #

# Imports #

In [None]:
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import numpy as np
import os
import cmocean as cm
import waypoint_distance as wd
import pandas as pd
from pathlib import Path
from datetime import datetime
from matplotlib.dates import DateFormatter
import gsw
import matplotlib.dates as mdates
%matplotlib widget

# ─── Load the transect cube and the bathymetry grid ──────────────────
# cube = xr.open_dataset(os.path.expanduser('~/Desktop/Summer 2025 Python/calvert_cube.nc'))
new_cube = xr.open_dataset(os.path.expanduser('~/Desktop/Summer 2025 Python/2024_2025_transect_cube.nc'))
# Drop the '20250317_out' transect from the cube.
new_cube = new_cube.sel(transect=new_cube.transect != '20250317_out')
# Re-reference the along-track coordinate: shift the origin by 25500 and then
# flip the sign, i.e. along_new = 25500 - along_old (presumably metres, since
# the plotting code divides by 1000 to get km).
new_cube = new_cube.assign_coords(along=new_cube['along'] - 25500)
new_cube = new_cube.assign_coords(along = new_cube['along']* (-1))
topo = xr.open_dataset(os.path.expanduser('~/Desktop/Summer 2025 Python/british_columbia_3_msl_2013.nc'))
# NOTE(review): the only definition of `station_coords` is the commented-out
# line below, yet the map cells further down pass `station_coords` to the
# plotting functions. Unless it is defined elsewhere, those calls raise
# NameError -- confirm where it should come from.
# ds = xr.open_dataset('~/Desktop/Summer 2025 Python/Hakai_calvert.nc')
# df_stations = ds[['station', 'latitude', 'longitude']].to_dataframe().dropna()
# df_stations_unique = df_stations.drop_duplicates(subset=['station', 'latitude', 'longitude'])
# station_coords = list(zip(df_stations_unique['latitude'],df_stations_unique['longitude'],df_stations_unique['station']))


# Open each station file individually
QCS07 = xr.open_dataset("~/Desktop/QCS07.nc")
FZH08 = xr.open_dataset("~/Desktop/FZH08.nc")
HKP04 = xr.open_dataset("~/Desktop/HKP04.nc")

def prep_station(ds, station_name):
    """Return a normalised copy of a station dataset.

    The input object is never modified: all changes are applied to a shallow
    copy, so the caller's dataset keeps its original dtypes.

    Parameters
    ----------
    ds : dataset-like
        Station dataset (as returned by ``xr.open_dataset``).
    station_name : str
        Value stored as the ``station`` coordinate when the dataset has none.

    Returns
    -------
    dataset-like
        A new object in which
        * ``time`` is cast to ``datetime64[ns]`` if it was stored as
          object / unicode / byte strings,
        * each of depth, temperature, salinity, latitude, longitude and
          pressure that is present and not already float or signed integer
          is cast to ``float64``,
        * a ``station`` coordinate is added if it was missing.
    """
    # Work on a shallow copy so the item assignments below cannot leak into
    # the caller's dataset (the original mutated its argument in place).
    ds = ds.copy(deep=False)

    # Ensure time is datetime64
    if 'time' in ds and ds['time'].dtype.kind in ('O', 'U', 'S'):
        ds = ds.assign_coords(time=ds['time'].astype('datetime64[ns]'))

    # Ensure numeric types
    for v in ('depth', 'temperature', 'salinity', 'latitude', 'longitude', 'pressure'):
        if v in ds and ds[v].dtype.kind not in ('f', 'i'):
            ds[v] = ds[v].astype('float64')

    # Add station coord if missing
    if 'station' not in ds:
        ds = ds.assign_coords(station=station_name)

    return ds

# Normalise the three station datasets, rebinding each name to its result.
QCS07, FZH08, HKP04 = (
    prep_station(station_ds, station_id)
    for station_ds, station_id in ((QCS07, "QCS07"), (FZH08, "FZH08"), (HKP04, "HKP04"))
)

# Mapping #

In [None]:
# Extract waypoints from the "20240416_out" transect (16 April 2024, going by
# the transect name), keeping only positions where along/longitude/latitude
# are all present.
ref = new_cube.sel(transect="20240416_out").dropna(dim='along', subset=['along', 'longitude', 'latitude'])

along_km = ref['along'].values / 1000  # Convert to km
lons = ref['longitude'].values
lats = ref['latitude'].values
target_kms = (25, 18, 10, 0, -10, -20, -30, -40, -50)  # Along-track positions to mark, from 25 km to -50 km

def plot_bathymetry_with_stations(
    topo, station_coords, selected_station_names=None,
    lon_bounds=None, lat_bounds=None, deepest=500
):
    """Draw the regional bathymetry map and save it as ``bathymetry_map.pdf``.

    Filled and line contours of ``-topo['Band1']`` are drawn inside the
    requested window and cells with non-finite depth are painted olive.
    Station markers, the transect line, place-name labels, the C46204 buoy
    and the along-track kilometre markers are layered on top as vector
    artists.

    The transect line and kilometre markers come from the module-level
    ``lons``, ``lats``, ``along_km`` and ``target_kms`` defined in the same
    cell, not from the arguments.

    Parameters
    ----------
    topo : dataset with ``lon`` / ``lat`` coordinates and a ``Band1`` variable.
    station_coords : iterable of ``(lat, lon, name)``; ``name`` may be bytes.
    selected_station_names : container of names to plot, or ``None`` for all
        stations that fall inside the bounds.
    lon_bounds, lat_bounds : ``(min, max)`` pairs; both are required.
    deepest : upper end of the 21 evenly spaced contour levels (starting at 0).

    Raises
    ------
    TypeError
        If ``lon_bounds`` or ``lat_bounds`` is left as ``None``.
    """
    if lon_bounds is None or lat_bounds is None:
        raise TypeError("lon_bounds and lat_bounds are required (min, max) pairs")

    import numpy.ma as ma

    plate = ccrs.PlateCarree()
    label_box = dict(facecolor='white', edgecolor='black',
                     boxstyle='round,pad=0.2', linewidth=0.5)

    topo_sel = topo.sel(
        lon=slice(lon_bounds[0], lon_bounds[1]),
        lat=slice(lat_bounds[0], lat_bounds[1])
    )
    depth = -topo_sel['Band1']

    fig, ax = plt.subplots(
        figsize=(2.5*1.5 * 6.4, 2.5 * 1 * 4.8),
        subplot_kw={'projection': plate}
    )
    ax.set_extent([*lon_bounds, *lat_bounds])
    gl = ax.gridlines(draw_labels=True, linestyle='--', linewidth=0.4, alpha=0)
    gl.xlabel_style = {'size': 20}
    gl.ylabel_style = {'size': 20}

    def boxed_text(x, y, text, size=13, ha='center', va='center'):
        # All map labels share the same white rounded box and sit above markers.
        ax.text(x, y, text, fontsize=size, fontweight='bold', color='black',
                ha=ha, va=va, zorder=3, transform=plate, bbox=label_box)

    # Artists with a zorder *below* this threshold are rasterized in vector
    # output. The bathymetry layers are drawn at zorder=0, so the threshold
    # must be above 0 (a threshold of 0 rasterizes nothing). Markers and
    # labels use zorder >= 2 and therefore stay vector.
    ax.set_rasterization_zorder(1)

    masked_depth = ma.masked_invalid(depth)
    levels = np.linspace(0, deepest, 21)

    # Filled bathymetry
    cf = ax.contourf(
        topo_sel['lon'], topo_sel['lat'], masked_depth,
        levels=levels, cmap=cm.cm.deep, extend='both',
        transform=plate, zorder=0
    )

    # Contour lines
    ax.contour(
        topo_sel['lon'], topo_sel['lat'], depth,
        levels=levels, colors='k', linewidths=0.3,
        transform=plate, zorder=0
    )

    cbar = fig.colorbar(cf, ax=ax, orientation='vertical', pad=0.08)
    cbar.set_label('Depth (m)', fontsize=25)

    # Cells with no finite depth value are painted olive
    missing_mask = ~np.isfinite(depth)
    ax.contourf(
        topo_sel['lon'], topo_sel['lat'], missing_mask,
        levels=[0.5, 1], colors=['olive'],
        transform=plate, zorder=0
    )

    # Station markers + labels
    for lat, lon, name in station_coords:
        name = name.decode().strip() if isinstance(name, bytes) else str(name).strip()
        if selected_station_names is not None and name not in selected_station_names:
            continue
        if not (lat_bounds[0] <= lat <= lat_bounds[1] and lon_bounds[0] <= lon <= lon_bounds[1]):
            continue
        ax.plot(lon, lat, 'o', color='red', markersize=8, zorder=2, transform=plate)
        if name == 'QCS07':
            # QCS07 is labelled below its marker; every other station to the right.
            boxed_text(lon, lat - 0.01, name, va='top')
        else:
            boxed_text(lon + 0.015, lat, name, ha='left')

    # Magenta transect line (module-level lons / lats)
    ax.plot(lons, lats, linestyle='--', color='magenta', linewidth=3, zorder=2, transform=plate)

    # Region labels
    boxed_text(-128.02, 51.575, "Calvert Island", size=15)
    boxed_text(-128.7, 51.51, "Queen Charlotte Sound", size=15)
    boxed_text(-127.7, 51.88, "Fitz Hugh Sound", size=15)
    boxed_text(-128.2, 51.8, "Hakai Passage", size=15)

    # West Sea Otter buoy
    ax.plot(-128.7500, 51.3683, 'o', color='blue', markersize=8, zorder=2, transform=plate)
    boxed_text(-128.7500, 51.3683 - 0.015, "C46204 Buoy")

    # Waypoint markers: nearest transect sample to each target distance
    for km in target_kms:
        idx = np.argmin(np.abs(along_km - km))
        x, y = lons[idx], lats[idx]
        label = "0 km" if np.isclose(km, 0) else f"{int(km)} km"

        ax.plot(x, y, 'o', color='lime', markersize=8, zorder=2, transform=plate)

        # Label placement is hand-tuned per waypoint
        if km in (18, 10, -40):
            boxed_text(x, y - 0.015, label, va='top')       # below marker
        elif km in (-50, 0):
            boxed_text(x, y + 0.015, label, va='bottom')    # above marker
        else:
            boxed_text(x + 0.02, y, label, ha='left')       # right of marker

    ax.set_aspect(1 / np.cos(np.deg2rad(np.mean(lat_bounds))))

    fig.savefig("bathymetry_map.pdf", dpi=300, bbox_inches='tight')

# Regional map: only QCS07 and FZH08, window 128.9-127.5 W / 51.3-52 N
plot_bathymetry_with_stations(
    topo, station_coords,
    selected_station_names=['QCS07', 'FZH08'],
    lon_bounds=(-128.9, -127.5), lat_bounds=(51.3, 52),
    deepest=500,
)

# Big map #

<!-- # Larger scale map # -->

In [None]:

def plot_west_coast(
    topo, station_coords, selected_station_names=None,
    lon_bounds=None, lat_bounds=None, deepest=4500,
    save_path="bathymetry_map.pdf", raster_dpi=300
):
    """Draw the large-scale coastal bathymetry map with the zoom-region box.

    The bathymetry inside the bounds is block-averaged 5x5 before contouring.
    A red rectangle marks the 128.9-127.5 W / 51.3-52.0 N window, four region
    names are added, and the buoy position is marked with a blue dot.

    ``station_coords`` and ``selected_station_names`` are accepted but not
    used. Saving is currently disabled (the ``savefig`` call is commented
    out), so ``save_path`` and ``raster_dpi`` are unused as well.
    """
    proj = ccrs.PlateCarree()

    # Subset to the window, then coarsen by a factor of 5 in both directions
    window = topo.sel(
        lon=slice(lon_bounds[0], lon_bounds[1]),
        lat=slice(lat_bounds[0], lat_bounds[1])
    )
    topo_sel = window.coarsen(lon=5, lat=5, boundary='trim').mean()
    grid_lon, grid_lat = topo_sel['lon'], topo_sel['lat']

    depth = -topo_sel['Band1']
    masked_depth = np.ma.masked_invalid(depth)
    levels = np.linspace(0, deepest, 21)

    fig, ax = plt.subplots(figsize=(14, 10), subplot_kw={'projection': proj})
    ax.set_extent([*lon_bounds, *lat_bounds])

    # Heavy layers sit at zorder 0, below this rasterization threshold
    ax.set_rasterization_zorder(1)

    # Filled bathymetry
    ax.contourf(grid_lon, grid_lat, masked_depth,
                levels=levels, cmap=cm.cm.deep, extend='both',
                transform=proj, zorder=0, rasterized=True)

    # Thin isobaths
    ax.contour(grid_lon, grid_lat, depth,
               levels=levels, colors='k', linewidths=0.1,
               transform=proj, zorder=0, rasterized=True)

    # Cells without a finite depth are painted olive
    missing_mask = ~np.isfinite(depth)
    ax.contourf(grid_lon, grid_lat, missing_mask,
                levels=[0.5, 1], colors=['olive'],
                transform=proj, zorder=0, rasterized=True)

    # Outline of the zoomed-in map region (closed rectangle)
    inset_lons = [-128.9, -127.5, -127.5, -128.9, -128.9]
    inset_lats = [51.3, 51.3, 52.0, 52.0, 51.3]
    ax.plot(inset_lons, inset_lats, color='red', linewidth=2, transform=proj, zorder=5)

    # Region names: (lon, lat, text, ha, va, box style)
    region_labels = (
        (-131.9, 53, "Haida Gwaii", 'center', 'center', 'round,pad=0.3'),
        (-126, 50, "Vancouver Island", 'center', 'center', 'round,pad=0.3'),
        (-126.5, 53.5, "British Columbia", 'left', 'top', 'round,pad=0.4'),
        (-130.5, 50, "Queen Charlotte Sound", 'center', 'center', 'round,pad=0.4'),
    )
    for x, y, text, halign, valign, box in region_labels:
        ax.text(x, y, text, fontsize=18, fontweight='bold',
                ha=halign, va=valign, transform=proj,
                bbox=dict(facecolor='white', edgecolor='black', boxstyle=box, linewidth=0.5))

    ax.set_aspect(1 / np.cos(np.deg2rad(np.mean(lat_bounds))))

    # West Sea Otter buoy (marker only; its text label is not drawn on this map)
    ax.plot(-128.7500, 51.3683, 'o', color='blue', markersize=8, zorder=2, transform=proj)

    # Saving is switched off in this version:
    # fig.savefig(save_path, dpi=raster_dpi, bbox_inches='tight')
    

# Coast-wide map, 133-123 W / 48.5-54 N, contour levels capped at 500 m
plot_west_coast(
    topo, station_coords,
    selected_station_names=['QCS07', 'FZH08'],
    lon_bounds=(-133, -123), lat_bounds=(48.5, 54),
    deepest=500,
    save_path="biggest_map.pdf", raster_dpi=300,
)

In [None]:
# ─── Calvert line waypoints, in order along the track ────────────────
# Each entry is expanded to {"station", "lon", "lat"} (values taken from the YAML).
CALVERT_WAYPOINTS = [
    {"station": name, "lon": lon, "lat": lat}
    for name, lon, lat in (
        ("FZH08", -127.9805, 51.7537),
        ("HP01", -128.011, 51.7413),
        ("HP02", -128.0428, 51.7295),
        ("HP03", -128.068, 51.721),
        ("HP04", -128.0905, 51.7102),
        ("HP05", -128.1192, 51.7053),
        ("HP08", -128.1452, 51.7045),
        ("HP09", -128.181, 51.7067),
        ("HP10", -128.2152, 51.7062),
        ("QCS01", -128.2767, 51.7117),
        ("CL02", -128.368, 51.705),
        ("CL03", -128.5, 51.464),
        ("CL04", -128.6605, 51.4083),
        ("CL05", -128.8237, 51.417),
        ("CL06", -128.8812, 51.3935),
        ("CL07", -128.9367, 51.3713),
        ("CL08", -128.9937, 51.348),
        ("CL09", -129.3543, 51.2515),
        ("CL10", -129.5478, 51.1842),
        ("CL11", -129.855, 51.081),
        ("CL12", -130.0, 51.0167),
        ("CL13", -130.979, 50.9187),
        ("CL14", -131.1492, 50.8493),
        ("P16", -134.6667, 49.2833),
    )
]

def plot_west_coast(
    topo, station_coords, selected_station_names=None,
    lon_bounds=None, lat_bounds=None, deepest=4500,
    save_path="bathymetry_map.pdf", raster_dpi=300,
    show_calvert_line=True
):
    """Draw the coast-wide bathymetry map, optionally with the Calvert line, and save it.

    The bathymetry inside the bounds is block-averaged 5x5 before contouring.
    A red rectangle marks the 128.9-127.5 W / 51.3-52.0 N window and four
    region names are added. When ``show_calvert_line`` is true, the
    module-level ``CALVERT_WAYPOINTS`` are joined by a dashed magenta line.
    The figure is written to ``save_path`` at ``raster_dpi``.

    ``station_coords`` and ``selected_station_names`` are accepted but not used.
    """
    proj = ccrs.PlateCarree()

    # Subset to the window, then coarsen by a factor of 5 in both directions
    window = topo.sel(
        lon=slice(lon_bounds[0], lon_bounds[1]),
        lat=slice(lat_bounds[0], lat_bounds[1])
    )
    topo_sel = window.coarsen(lon=5, lat=5, boundary='trim').mean()
    grid_lon, grid_lat = topo_sel['lon'], topo_sel['lat']

    depth = -topo_sel['Band1']
    masked_depth = np.ma.masked_invalid(depth)
    levels = np.linspace(0, deepest, 21)

    fig, ax = plt.subplots(figsize=(14, 10), subplot_kw={'projection': proj})
    ax.set_extent([*lon_bounds, *lat_bounds])

    # Heavy layers sit at zorder 0, below this rasterization threshold
    ax.set_rasterization_zorder(1)

    # Filled bathymetry
    ax.contourf(grid_lon, grid_lat, masked_depth,
                levels=levels, cmap=cm.cm.deep, extend='both',
                transform=proj, zorder=0, rasterized=True)

    # Thin isobaths
    ax.contour(grid_lon, grid_lat, depth,
               levels=levels, colors='k', linewidths=0.1,
               transform=proj, zorder=0, rasterized=True)

    # Cells without a finite depth are painted olive
    missing_mask = ~np.isfinite(depth)
    ax.contourf(grid_lon, grid_lat, missing_mask,
                levels=[0.5, 1], colors=['olive'],
                transform=proj, zorder=0, rasterized=True)

    # Outline of the zoomed-in map region (closed rectangle)
    inset_lons = [-128.9, -127.5, -127.5, -128.9, -128.9]
    inset_lats = [51.3, 51.3, 52.0, 52.0, 51.3]
    ax.plot(inset_lons, inset_lats, color='red', linewidth=2, transform=proj, zorder=5)

    # Region names: (lon, lat, text, ha, va, box style)
    region_labels = (
        (-131.9, 53, "Haida Gwaii", 'center', 'center', 'round,pad=0.3'),
        (-126, 50, "Vancouver Island", 'center', 'center', 'round,pad=0.3'),
        (-126.5, 53.5, "British Columbia", 'left', 'top', 'round,pad=0.4'),
        (-130.5, 50, "Queen Charlotte Sound", 'center', 'center', 'round,pad=0.4'),
    )
    for x, y, text, halign, valign, box in region_labels:
        ax.text(x, y, text, fontsize=18, fontweight='bold',
                ha=halign, va=valign, transform=proj,
                bbox=dict(facecolor='white', edgecolor='black', boxstyle=box, linewidth=0.5))

    # Calvert line: drawn at zorder 6, above the rasterization threshold
    if show_calvert_line:
        calvert_lons, calvert_lats = [], []
        for waypoint in CALVERT_WAYPOINTS:
            calvert_lons.append(waypoint["lon"])
            calvert_lats.append(waypoint["lat"])
        ax.plot(calvert_lons, calvert_lats,
                linestyle='--', linewidth=2, marker=None, markersize=3,
                transform=proj, zorder=6, color='magenta')

    ax.set_aspect(1 / np.cos(np.deg2rad(np.mean(lat_bounds))))

    # Write the figure; dpi applies to the rasterized layers
    fig.savefig(save_path, dpi=raster_dpi, bbox_inches='tight')

# Same coast-wide map as before, now with the Calvert line overlaid
plot_west_coast(
    topo, station_coords,
    selected_station_names=['QCS07', 'FZH08'],
    lon_bounds=(-133, -123), lat_bounds=(48.5, 54),
    deepest=500,
    save_path="biggest_map.pdf", raster_dpi=300,
    show_calvert_line=True,
)


# Hakai Passage Map #

In [None]:
# Reference track for the Hakai Passage map: the "20240312_out" transect,
# restricted to positions where along/longitude/latitude are all present.
ref = new_cube.sel(transect="20240312_out")
ref = ref.dropna(dim='along', subset=['along', 'longitude', 'latitude'])

lons, lats = ref['longitude'].values, ref['latitude'].values
along_km = ref['along'].values / 1000  # along-track distance in km
target_kms = (25, 18, 10, 0)  # along-track positions to mark on the map

def plot_bathymetry_with_stations(
    topo,
    station_coords,
    selected_station_names=None,
    lon_bounds=None,
    lat_bounds=None,
    deepest=500
):
    """Draw the Hakai Passage close-up map and save it as ``hakai_passage_map.pdf``.

    Bathymetry layers (filled contours, isobaths, olive no-data mask) are
    drawn with ``rasterized=True``. Stations, the transect line, place names
    and kilometre markers are added on top. The transect line and kilometre
    markers use the module-level ``lons``, ``lats``, ``along_km`` and
    ``target_kms``.

    ``station_coords`` is an iterable of ``(lat, lon, name)``. Only stations
    inside the bounds, and in ``selected_station_names`` when one is given,
    are plotted.
    """
    proj = ccrs.PlateCarree()
    box_style = dict(facecolor='white', edgecolor='black', boxstyle='round,pad=0.2', linewidth=0.5)

    # Bathymetry inside the requested window
    topo_sel = topo.sel(
        lon=slice(lon_bounds[0], lon_bounds[1]),
        lat=slice(lat_bounds[0], lat_bounds[1])
    )
    grid_lon, grid_lat = topo_sel['lon'], topo_sel['lat']
    depth = -topo_sel['Band1']
    masked_depth = np.ma.masked_invalid(depth)
    levels = np.linspace(0, deepest, 21)

    # Figure, extent, and invisible gridlines that only supply tick labels
    fig, ax = plt.subplots(
        figsize=(2.5*1.5 * 6.4, 2.5* 1 * 4.8),
        subplot_kw={'projection': proj}
    )
    ax.set_extent([*lon_bounds, *lat_bounds])
    gl = ax.gridlines(draw_labels=True, linestyle='--', linewidth=0.4, alpha=0)
    gl.xlabel_style = {'size': 20}
    gl.ylabel_style = {'size': 20}

    def boxed(x, y, text, size, ha, va):
        # Bold black text in a white rounded box
        ax.text(x, y, text, fontsize=size, fontweight='bold', color='black',
                ha=ha, va=va, transform=proj, bbox=box_style)

    # Rasterized bathymetry layers
    cf = ax.contourf(grid_lon, grid_lat, masked_depth,
                     levels=levels, cmap=cm.cm.deep, extend='both',
                     transform=proj, zorder=0, rasterized=True)
    ax.contour(grid_lon, grid_lat, depth,
               levels=levels, colors='k', linewidths=0.3,
               transform=proj, zorder=0, rasterized=True)

    # Cells without a finite depth are painted olive
    missing_mask = ~np.isfinite(depth)
    ax.contourf(grid_lon, grid_lat, missing_mask,
                levels=[0.5, 1], colors=['olive'],
                transform=proj, zorder=0, rasterized=True)

    cbar = fig.colorbar(cf, ax=ax, orientation='vertical', pad=0.08)
    cbar.set_label('Depth (m)', fontsize=25)

    # Stations
    for lat, lon, name in station_coords:
        name = name.decode().strip() if isinstance(name, bytes) else str(name).strip()
        if selected_station_names is not None and name not in selected_station_names:
            continue
        if not (lat_bounds[0] <= lat <= lat_bounds[1] and lon_bounds[0] <= lon <= lon_bounds[1]):
            continue
        ax.plot(lon, lat, 'o', color='red', markersize=8, transform=proj)
        if name == 'QCS07':
            boxed(lon, lat - 0.01, name, 13, 'center', 'top')
        else:
            boxed(lon + 0.005, lat, name, 13, 'left', 'center')

    # Transect line
    ax.plot(lons, lats, linestyle='--', color='magenta', linewidth=3,
            transform=proj, zorder=2)

    # Place names
    for x, y, text in (
        (-128.02, 51.69, "Calvert Island"),
        (-127.94, 51.72, "Fitz Hugh Sound"),
        (-128.15, 51.735, "Hakai Passage"),
    ):
        boxed(x, y, text, 15, 'center', 'center')

    # Kilometre markers at the transect sample nearest each target distance
    for km in target_kms:
        idx = np.argmin(np.abs(along_km - km))
        x, y = lons[idx], lats[idx]
        label = "0 km" if np.isclose(km, 0) else f"{int(km)} km"
        ax.plot(x, y, 'o', color='lime', markersize=8, transform=proj)
        if km in (18, 10, 0):
            boxed(x, y - 0.004, label, 13, 'center', 'top')     # below marker
        else:
            boxed(x + 0.004, y, label, 13, 'left', 'center')    # right of marker

    ax.set_aspect(1 / np.cos(np.deg2rad(np.mean(lat_bounds))))

    # High-dpi output for the rasterized layers
    fig.savefig("hakai_passage_map.pdf", dpi=600, bbox_inches='tight')

# Hakai Passage close-up: FZH08, HKP04 and HKP01 where they fall inside the window
plot_bathymetry_with_stations(
    topo, station_coords,
    selected_station_names=['FZH08', 'HKP04', "HKP01"],
    lon_bounds=(-128.3, -127.85), lat_bounds=(51.66, 51.8),
    deepest=500,
)

# Plot temperature sections #

In [None]:
def plot_temperature_sections(cube, topo, target_years, ncols=3, xlim=(77, 0), vmin=5.3, vmax=10):
    """Plot a grid of temperature sections, one panel per transect, on a 6x8 inch figure.

    Each panel shows potential temperature as a colour mesh, potential-density
    anomaly contours (``potential_density - 1000``), and the sea floor sampled
    from ``topo['Band1']`` at the nearest grid point to each transect position.
    The figure is saved to ``~/Desktop/temperature_section_grid_3col.pdf``.

    Parameters
    ----------
    cube : dataset with a ``transect`` coordinate whose names start with
        ``YYYYMMDD`` and with ``potential_temperature``, ``potential_density``,
        ``depth``, ``along``, ``longitude`` and ``latitude`` variables.
    topo : dataset with ``lon`` / ``lat`` coordinates and a ``Band1`` variable.
    target_years : container of ints; transects whose name starts with one of
        these years are plotted.
    ncols : number of panel columns.
    xlim : x-axis limits in km, passed to ``set_xlim`` as given.
    vmin, vmax : colour limits for temperature.

    Returns
    -------
    None. If no transect matches ``target_years`` a message is printed and
    nothing is plotted or saved.

    Notes
    -----
    Updates the global ``plt.rcParams`` font sizes as a side effect.
    """
    # ─── Filter transects by the year encoded in their name ──────────
    selected = [str(t) for t in cube.transect.values if int(str(t)[:4]) in target_years]
    if not selected:
        # An empty selection would otherwise fail inside plt.subplots(0, ...)
        print("No transects found for the requested years.")
        return

    plt.rcParams.update({
        'font.size': 8,
        'axes.titlesize': 10,
        'axes.labelsize': 9,
        'xtick.labelsize': 8,
        'ytick.labelsize': 8,
        'legend.fontsize': 8,
        'figure.titlesize': 10
    })

    cube_sel = cube.sel(transect=selected)
    n = len(selected)
    nrows = int(np.ceil(n / ncols))

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so that
    # flatten() always works.
    fig, axes = plt.subplots(nrows, ncols, figsize=(6, 8),
                             sharex=True, sharey=True, squeeze=False)
    axes = axes.flatten()

    # Isopycnal sets: (levels, colour, line width)
    isopycnal_specs = [
        (np.linspace(24, 27, 7), 'black', 0.3),
        ([25.6], 'white', 0.3),
        ([25.7], 'lime', 0.3),
        ([25.8], 'red', 0.3),
        ([25.9], 'blue', 0.3),
        ([26.1], 'purple', 0.3),
        ([26.2], 'salmon', 0.3),
        ([26.3], 'yellow', 0.3),
        ([26.4], 'cyan', 0.3),
    ]

    for ax, tran_name in zip(axes, selected):
        ds = cube_sel.sel(transect=tran_name)

        temp = ds['potential_temperature'].values
        pdens = ds['potential_density'].values - 1000
        depth = ds['depth'].values
        along_km = ds['along'].values / 1000
        lon = ds['longitude'].values
        lat = ds['latitude'].values

        # Sea floor under the track: nearest bathymetry cell, sign flipped
        interp_bathy = topo['Band1'].interp(
            lon=xr.DataArray(lon, dims='along'),
            lat=xr.DataArray(lat, dims='along'),
            method='nearest')
        ocean_floor = -np.asarray(interp_bathy.values, dtype=np.float32)

        # ─── Plot ───────────────────────────────────────────
        ax.fill_between(along_km, ocean_floor, 500, where=~np.isnan(ocean_floor),
                        facecolor='grey', zorder=1)

        pc = ax.pcolormesh(along_km, depth, temp,
                           shading='auto', cmap=cm.cm.thermal,
                           vmin=vmin, vmax=vmax, zorder=2, rasterized=True)

        ax.plot(along_km, ocean_floor, color='black', linewidth=0.8)

        for levels, color, lw in isopycnal_specs:
            ax.contour(along_km, depth, pdens, levels=levels,
                       colors=color, linewidths=lw)

        date_str = pd.to_datetime(tran_name[:8], format='%Y%m%d').strftime('%b %d')
        leg = 'Out' if 'out' in tran_name else 'Return'
        ax.text(0.5, 0.2, f"{date_str} ({leg})",
                transform=ax.transAxes,
                ha='center', va='top',
                fontsize=8, color='black')
        ax.label_outer()   # hides inner labels automatically

    # Axes are shared, so setting limits on one panel applies to all.
    # Passing (440, 0) puts depth increasing downwards.
    last_used = axes[n - 1]
    last_used.set_xlim(xlim)
    last_used.set_ylim(440, 0)

    # Axis titles only on the bottom-left panel
    bottom_left = axes[(nrows - 1) * ncols]
    bottom_left.set_xlabel("Distance along (km)")
    bottom_left.set_ylabel("Depth (m)")

    # Remove the unused panels of the last row
    for spare in axes[n:]:
        fig.delaxes(spare)

    # ─── Tighter layout with room for horizontal colorbar ─────
    plt.tight_layout(rect=[0.08, 0.08, 0.95, 0.9], h_pad=0.05, w_pad=0.05)

    # Horizontal colorbar spanning all subplots, placed on top
    cbar = fig.colorbar(
        pc, ax=axes, orientation='horizontal',
        location='top',
        fraction=0.05, pad=0.02)
    cbar.set_label("Temperature (°C)", fontsize=9, labelpad=2)
    cbar.ax.tick_params(labelsize=8)

    # Move ticks/label to the top
    cbar.ax.xaxis.set_ticks_position('top')
    cbar.ax.xaxis.set_label_position('top')

    fig.savefig(os.path.expanduser("~/Desktop/temperature_section_grid_3col.pdf"),
                format='pdf', bbox_inches='tight', dpi=300)

# Temperature section grid for every 2024 and 2025 transect, three panels per row
plot_temperature_sections(
    new_cube, topo,
    target_years=[2024, 2025],
    ncols=3,
    xlim=(-53, 25.5),
    vmin=5.3, vmax=10,
)

# Temperature time series of different along values #

In [None]:
def plot_temp_time_series_subplots(
    cube,
    along_values,
    target_years=(2019, 2020, 2021, 2022, 2023, 2024, 2025),
    ylims=((400, 0), (130, 0), (200, 0)),
    start_month=None
):
    """Plot stacked time-depth temperature panels at fixed along-track positions.

    One panel is drawn per entry of ``along_values``. Each panel shows
    ``potential_temperature`` as a colour mesh with ``potential_density - 1000``
    contours on top, using one column per transect date. The figure is saved to
    ``~/Desktop/along_time_series.pdf``.

    Parameters
    ----------
    cube : xarray.Dataset
        Must provide ``along``, ``depth`` and ``transect`` coordinates plus
        ``potential_temperature`` and ``potential_density``. Transect labels
        must start with a ``YYYYMMDD`` date.
        NOTE(review): the data variables are assumed to be ordered
        (transect, depth) once ``along`` is selected -- confirm.
    along_values : sequence of numbers
        Along-track positions to select with ``cube.sel(along=...)``. They are
        divided by 1000 for the "km" panel label.
    target_years : sequence of int
        Only transects whose year is in this collection are used.
    ylims : sequence of (deep, shallow) pairs
        One y-limit pair per panel, in the same order as ``along_values``.
        The first element of each pair also sets the relative panel height.
    start_month : int or None
        If given, transects before the first day of this month in
        ``target_years[0]`` are dropped.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``along_values`` is empty or ``ylims`` does not have one entry per
        along value.
    """
    import os
    from datetime import datetime

    import cmocean as cm
    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib.dates import DateFormatter

    along_values = list(along_values)
    ylims = list(ylims)
    if not along_values:
        raise ValueError("along_values must contain at least one position")
    if len(ylims) != len(along_values):
        raise ValueError(
            f"ylims has {len(ylims)} entries but along_values has "
            f"{len(along_values)}; supply one (deep, shallow) pair per panel"
        )

    start_floor = (
        None if start_month is None
        else datetime(target_years[0], start_month, 1)
    )

    def in_window(t):
        """Return True when *t* passes the year and start-month filters."""
        return t.year in target_years and (start_floor is None or t >= start_floor)

    # (levels, colour, line width, draw labels?) for the density contours.
    # Only the two black sets are labelled.
    isopycnal_sets = [
        (np.linspace(24, 27, 7), 'black', 0.31, True),
        ([25.6], 'white', 0.3, False),
        ([25.7], 'lime', 0.3, False),
        ([25.8], 'red', 0.3, False),
        ([25.9], 'blue', 0.3, False),
        ([26.0], 'black', 0.31, True),
        ([26.1], 'purple', 0.3, False),
        ([26.2], 'salmon', 0.3, False),
        ([26.3], 'yellow', 0.3, False),
        ([26.4], 'cyan', 0.3, False),
    ]

    # Select each along position once and parse its transect dates; the
    # window-filtered dates of every panel define the shared x-limits.
    panels = []
    all_times = []
    for along_value in along_values:
        cube_sel = cube.sel(along=along_value)
        times = [datetime.strptime(t[:8], "%Y%m%d") for t in cube_sel.transect.values]
        panels.append((cube_sel, times))
        all_times.extend(t for t in times if in_window(t))

    # Bail out before a figure is created so no empty figure is left open.
    if not all_times:
        print("No valid time data found.")
        return

    global_xlim = (min(all_times), max(all_times))

    # Font sizes are read when axes and labels are created, so they have to be
    # set before the figure is built to affect this figure on the first run.
    plt.rcParams.update({
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "axes.titlesize": 10,
        "axes.labelsize": 9,
        "figure.titlesize": 12
    })

    # Panel heights are proportional to the deepest limit of each panel.
    max_depth = max(yl[0] for yl in ylims)
    height_ratios = [yl[0] / max_depth for yl in ylims]

    n_panels = len(along_values)
    fig, axes = plt.subplots(
        nrows=n_panels, figsize=(6, 6),
        constrained_layout=True,
        gridspec_kw={'height_ratios': height_ratios}
    )
    axes = [axes] if n_panels == 1 else list(axes)

    pc = None  # last mesh drawn; stays None if every panel is empty
    for ax_i, (ax, along_value, ylim, (cube_sel, times)) in enumerate(
            zip(axes, along_values, ylims, panels)):
        temp = cube_sel['potential_temperature'].values
        density = cube_sel['potential_density'].values - 1000
        depth = cube_sel['depth'].values

        profiles = [
            (t, te, de)
            for t, te, de in zip(times, temp, density)
            if in_window(t) and not np.all(np.isnan(te))
        ]
        if not profiles:
            print(f"No data for along={along_value}")
            continue

        # One column per distinct date. Profiles are written in transect order,
        # so when two transects share a date the later one wins. No tuple sort
        # is needed, which also avoids comparing ndarrays on equal dates.
        this_times = sorted({t for t, _, _ in profiles})
        column_of = {t: col for col, t in enumerate(this_times)}
        temp_grid = np.full((profiles[0][1].shape[0], len(this_times)), np.nan)
        density_grid = np.full_like(temp_grid, np.nan)
        for t, te, de in profiles:
            temp_grid[:, column_of[t]] = te
            density_grid[:, column_of[t]] = de

        pc = ax.pcolormesh(this_times, depth, temp_grid,
                           cmap=cm.cm.thermal, shading='nearest',
                           vmin=5.3, vmax=10)

        # contour() needs at least a 2x2 field, so skip it for a single date.
        if len(this_times) >= 2:
            for levels, color, lw, labelled in isopycnal_sets:
                cf_iso = ax.contour(this_times, depth, density_grid,
                                    levels=levels, colors=color, linewidths=lw)
                if labelled:
                    ax.clabel(cf_iso, fmt='%1.2f', fontsize=6, inline=False)

        # ─── Label each plot with its distance from the sill ───
        km_from_sill = along_value / 1000
        if abs(km_from_sill) < 0.25:
            label = "0 km (sill)"
        else:
            label = f"{int(round(km_from_sill))} km"

        ax.text(0.01, 0.95,
                label,
                transform=ax.transAxes,
                fontsize=9,
                fontweight="bold",
                va="top", ha="left")

        # ylim is passed through as given; (deep, shallow) pairs put the surface on top.
        ax.set_ylim(ylim)
        ax.set_xlim(global_xlim)
        ax.xaxis.set_major_formatter(DateFormatter('%b %d\n%Y'))
        ax.axhline(130, color='k', linestyle='--', linewidth=0.8)

        # Only the bottom panel keeps its depth and date tick labels.
        if ax_i != n_panels - 1:
            ax.set_yticklabels([])
            ax.set_ylabel("")
            ax.set_xticklabels([])
        else:
            ax.set_ylabel("Depth (m)")

        if ax_i == 0:
            # top axis ticks at actual sample times
            ax_top = ax.secondary_xaxis('top')
            ax_top.set_xticks(this_times)
            ax_top.set_xticklabels(
                # t.day instead of the glibc-only '%-d' directive
                [f"{t:%b} {t.day}" for t in this_times],
                rotation=90, fontsize=7
            )
            ax_top.tick_params(axis='x', direction='out', length=2, width=0.5, pad=1)

    # Slimmer, squished-in colorbar (only when at least one panel was drawn)
    if pc is not None:
        cbar = fig.colorbar(pc, ax=axes, label="Temperature (°C)",
                            orientation='vertical', pad=0.01)
        cbar.ax.tick_params(labelsize=8)

    fig.savefig(os.path.expanduser("~/Desktop/along_time_series.pdf"),
                format='pdf', bbox_inches='tight', dpi=300)

# Time-depth temperature panels at three along positions. The function labels each
# panel with along/1000 km, so these are -40 km, 0 km (sill) and 18 km. One
# (deep, shallow) y-limit pair is given per panel. With target_years starting at 2024
# and start_month=1, transects before 1 Jan 2024 are excluded.
plot_temp_time_series_subplots(
    new_cube,
    along_values=[-40000, 0, 18000],
    ylims=[(200,0), (140,0), (400,0)],
    target_years=[2024,2025], start_month = 1)

# Sill comparison plot #

In [None]:
# def plot_combined_sill_time_series(
#     cube,
#     ds_hakai,
#     along_values=[25500, 8000],
#     hakai_station="QCS07",
#     ylims=[(200,0), (130,0), (400,0)],
#     target_years=[2024, 2025],
#     start_month=None
# ):
#     import matplotlib.pyplot as plt
#     from matplotlib.dates import DateFormatter
#     import pandas as pd
#     import numpy as np
#     from datetime import datetime
#     import cmocean as cm
#     import gsw

#     # Add one more ylimit for Hakai subplot
#     ylims = ylims + [(125, 0)]

#     # ─── Calculate height ratios ───
#     max_depths = [yl[0] for yl in ylims]
#     max_depth = max(max_depths)
#     height_ratios = [d / max_depth for d in max_depths]

#     nrows = len(ylims)
#     fig, axes = plt.subplots(
#         nrows=nrows,
#         figsize=(1.3 * 16, 1.2 * 14),
#         constrained_layout=True,
#         gridspec_kw={'height_ratios': height_ratios}
#     )
#     fig.suptitle("Hakai Passage Sill, Basin, and Southern Sill Comparison", fontsize=40)

#     # ─── Gather global time limits ───
#     all_times = []

#     # ─── Glider Data Subplots ───
#     for ax, along_value, ylim in zip(axes[:-1], along_values, ylims[:-1]):
#         cube_sel = cube.sel(along=along_value)
#         temp = cube_sel['temperature'].values
#         density = cube_sel['potential_density'].values - 1000
#         depth = cube_sel['depth'].values
#         times = [datetime.strptime(t[:8], "%Y%m%d") for t in cube_sel.transect.values]

#         filtered = [
#             (t, te, de)
#             for t, te, de in zip(times, temp, density)
#             if (
#                 t.year in target_years
#                 and (start_month is None or t >= datetime(target_years[0], start_month, 1))
#                 and not np.all(np.isnan(te))
#             )
#         ]

#         if not filtered:
#             print(f"No data for along={along_value}")
#             continue

#         times_data, temp_data, density_data = zip(*sorted(filtered))
#         this_times = sorted(set(times_data))
#         temp_grid = np.full((temp_data[0].shape[0], len(this_times)), np.nan)
#         density_grid = np.full_like(temp_grid, np.nan)

#         for i, t in enumerate(times_data):
#             time_idx = this_times.index(t)
#             temp_grid[:, time_idx] = temp_data[i]
#             density_grid[:, time_idx] = density_data[i]

#         pc = ax.pcolormesh(this_times, depth, temp_grid,
#                            cmap=cm.cm.thermal, shading='nearest',
#                            vmin=5.3, vmax=10)

#         for levels, color, lw in [
#                 (np.linspace(24, 27, 7), 'black', .51),
#                 ([25.6], 'white', 0.5),
#                 ([25.7], 'lime', 0.5),
#                 ([25.8], 'red', 0.5),
#                 ([25.9], 'blue', 0.5),
#                 ([26.0], 'black', 0.51),
#                 ([26.1], 'purple', 0.5),
#                 ([26.2], 'salmon', 0.5),
#                 ([26.3], 'yellow', 0.5),
#                 ([26.4], 'cyan', 0.5)]:
#             cf_iso = ax.contour(this_times, depth, density_grid, levels=levels, colors=color, linewidths=lw)
#             if lw != 0.3:
#                 ax.clabel(cf_iso, fmt='%1.2f')

#         all_times.extend(this_times)

#         # ─── Along-distance label ───
#         km_from_sill = (along_value) / 1000
#         if abs(km_from_sill) < 0.25:
#             label = "0 km (Hakai Passage Sill)"
#         elif km_from_sill < 0:
#             label = f"{int(round(km_from_sill))} km (FHS Basin)"
#         else:
#             label = f"+{int(round(km_from_sill))} km"

#         ax.text(0.01, 0.95, label, transform=ax.transAxes,
#                 fontsize=26, fontweight="bold", va="top", ha="left")

#         ax.invert_yaxis()
#         ax.set_ylim(ylim)
#         ax.xaxis.set_major_formatter(DateFormatter('%b %d\n%Y'))
#         ax.set_xlim(min(all_times), max(all_times))

#         if ax is not axes[-1]:
#             ax.set_xticklabels([])

#     # ─── Hakai Station Subplot ───
#     ax = axes[-1]
#     ds_station = ds_hakai.where(ds_hakai['station'] == hakai_station, drop=True)
#     ds_station = ds_station.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time'])
#     ds_station = ds_station.where(ds_station['time'].dt.year.isin(target_years), drop=True)

#     # Interpolate onto regular depth
#     regular_depth = np.arange(0, 500, 1)
#     grouped = ds_station.groupby('time')

#     temp_profiles = []
#     dens_profiles = []
#     valid_times = []

#     for t, group in grouped:
#         if group.sizes['row'] < 2:
#             continue
#         d = group['depth'].values
#         T = group['temperature'].values
#         S = group['salinity'].values
#         σ = gsw.sigma0(S, T)
#         temp_profiles.append(np.interp(regular_depth, d, T, left=np.nan, right=np.nan))
#         dens_profiles.append(np.interp(regular_depth, d, σ, left=np.nan, right=np.nan))
#         time_dt = pd.Timestamp(t).to_pydatetime()
#         valid_times.append(time_dt)
#         all_times.append(time_dt)

#     temp = np.array(temp_profiles).T
#     dens = np.array(dens_profiles).T

#     pc = ax.pcolormesh(valid_times, regular_depth, temp,
#                        cmap=cm.cm.thermal, shading='nearest',
#                        vmin=5.3, vmax=10)

#     for levels, color, lw in [
#             (np.linspace(24, 27, 7), 'black', .51),
#             ([25.6], 'white', 0.5),
#             ([25.7], 'lime', 0.5),
#             ([25.8], 'red', 0.5),
#             ([25.9], 'blue', 0.5),
#             ([26.0], 'black', 0.51)]:
#         cf = ax.contour(valid_times, regular_depth, dens,
#                         levels=levels, colors=color, linewidths=lw)
#         if lw != 0.3:
#             ax.clabel(cf, fmt='%1.2f')

#     ax.text(0.01, 0.95, f"{hakai_station} (Southern Sill)",
#             transform=ax.transAxes,
#             fontsize=26, fontweight="bold",
#             va="top", ha="left")

#     ax.invert_yaxis()
#     ax.set_ylim(ylims[-1])
#     ax.set_xlim(min(all_times), max(all_times))
#     ax.xaxis.set_major_formatter(DateFormatter('%b %d\n%Y'))
#     ax.set_xlabel("Date", fontsize=20)

#     # Top ticks on first axis
#     axes[0].secondary_xaxis('top').set_visible(False)

#     plt.rcParams.update({
#         "xtick.labelsize": 24,
#         "ytick.labelsize": 24,
#         "axes.titlesize": 26,
#         "axes.labelsize": 26,
#         "figure.titlesize": 30
#     })

#     fig.colorbar(pc, ax=axes, label="Temperature (°C)", orientation='vertical')
#     fig.supylabel("Depth (m)", fontsize=30)

# plot_combined_sill_time_series(
#     cube=new_cube,
#     ds_hakai=ds,
#     along_values=[0, -18000],  # North sill, center, south glider
#     hakai_station="QCS07",
#     ylims=[(130, 0), (400, 0)],  # ylims for glider, auto adds Hakai
#     target_years=[2024, 2025],
#     start_month=1
# )

In [None]:
def plot_combined_sill_time_series(
    cube,
    ds_hakai,
    along_values=(0, -18000),            # glider panels (e.g., sill, basin)
    hakai_station="QCS07",
    ylims_glider=((130, 0), (400, 0)),   # y-lims for glider panels (order must match along_values)
    ylim_hakai=(125, 0),                 # y-lim for Hakai panel
    target_years=(2024, 2025),
    start_month=1,
    vmin=5.3,
    vmax=10,
    save_path="~/Desktop/along_time_series.pdf"
):
    """
    Make a vertically stacked time-depth temperature plot and save it as a PDF.

    - Top N panels: glider cuts of ``cube`` at the given along-values, showing
      ``potential_temperature`` with ``potential_density - 1000`` contours.
    - Bottom panel: cast time series for ``hakai_station`` from ``ds_hakai``,
      interpolated onto a 0-499 m, 1 m depth grid.
    Styled to match `plot_temp_time_series_subplots`.

    Parameters
    ----------
    cube : xarray.Dataset
        Needs ``along``, ``depth`` and ``transect`` coordinates plus
        ``potential_temperature`` and ``potential_density``. Transect labels
        must start with ``YYYYMMDD``.
        NOTE(review): data variables are assumed to be (transect, depth) after
        selecting ``along`` -- confirm.
    ds_hakai : xarray.Dataset
        Row-indexed cast data with ``station``, ``time``, ``depth``,
        ``temperature`` and ``salinity``; ``latitude``/``longitude`` are used
        when present.
        NOTE(review): ``salinity`` is treated as Practical Salinity and
        ``temperature`` as in-situ temperature, as in the sibling
        ``plot_combined_sill_time_series_two_hakai`` -- confirm.
    along_values : sequence of numbers
        Along positions for the glider panels.
    hakai_station : str
        Station label to select from ``ds_hakai``.
    ylims_glider : sequence of (deep, shallow) pairs
        One per along value, same order.
    ylim_hakai : (deep, shallow) pair
        Y-limits of the Hakai panel.
    target_years : sequence of int
        Years to keep, applied to glider transects and Hakai casts.
    start_month : int or None
        If given, glider transects before the first day of this month in
        ``target_years[0]`` are dropped and the shared x-axis does not start
        before that date. ``None`` disables both.
    vmin, vmax : float
        Colour limits of the temperature mesh.
    save_path : str
        Output PDF path; ``~`` is expanded.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``ylims_glider`` does not have one entry per along value.
    """
    import os
    from datetime import datetime

    from matplotlib.dates import DateFormatter

    along_values = list(along_values)
    ylims_glider = list(ylims_glider)
    target_years = list(target_years)
    if len(ylims_glider) != len(along_values):
        raise ValueError(
            f"ylims_glider has {len(ylims_glider)} entries but along_values has "
            f"{len(along_values)}; supply one (deep, shallow) pair per glider panel"
        )

    start_floor = (
        None if start_month is None
        else datetime(target_years[0], start_month, 1)
    )

    def in_window(t):
        """Return True when *t* passes the year and start-month filters."""
        return t.year in target_years and (start_floor is None or t >= start_floor)

    # (levels, colour, line width, draw labels?) shared by glider and Hakai panels.
    isopycnal_sets = [
        (np.linspace(24, 27, 7), 'black', 0.31, True),
        ([25.6], 'white', 0.3, False),
        ([25.7], 'lime', 0.3, False),
        ([25.8], 'red', 0.3, False),
        ([25.9], 'blue', 0.3, False),
        ([26.0], 'black', 0.31, True),
        ([26.1], 'purple', 0.3, False),
        ([26.2], 'salmon', 0.3, False),
        ([26.3], 'yellow', 0.3, False),
        ([26.4], 'cyan', 0.3, False),
    ]

    def draw_isopycnals(target_ax, x, y, field):
        """Overlay the density contour sets on *target_ax*."""
        # contour() needs at least a 2x2 field; a single time column cannot be contoured.
        if len(x) < 2:
            return
        for levels, color, lw, labelled in isopycnal_sets:
            contours = target_ax.contour(x, y, field, levels=levels,
                                         colors=color, linewidths=lw)
            if labelled:
                target_ax.clabel(contours, fmt='%1.2f', fontsize=8, inline=False)

    # ─── Collect times across ALL panels to set a global xlim ────────────────────
    all_times = []
    glider_panels = []
    for along_value in along_values:
        cube_sel = cube.sel(along=along_value)
        times = [datetime.strptime(t[:8], "%Y%m%d") for t in cube_sel.transect.values]
        glider_panels.append((cube_sel, times))
        all_times.extend(t for t in times if in_window(t))

    # Hakai casts for the requested station and years (complete rows only)
    ds_station = ds_hakai.where(ds_hakai['station'] == hakai_station, drop=True)
    ds_station = ds_station.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time'])
    ds_station = ds_station.where(ds_station['time'].dt.year.isin(target_years), drop=True)
    n_hakai_rows = ds_station.sizes.get('row', 0)

    # ─── If nothing to plot, bail gracefully (before any figure is created) ─────
    if len(all_times) == 0 and n_hakai_rows == 0:
        print("No valid time data found in glider or Hakai.")
        return

    # ─── Matplotlib rcparams for small, clean look ───────────────────────────────
    # Set before the figure is built: font sizes are read when axes and labels
    # are created, so a later update would not restyle this figure.
    plt.rcParams.update({
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "axes.titlesize": 12,
        "axes.labelsize": 12,
    })

    # ─── Build ylims list (glider + Hakai) and height ratios ─────────────────────
    ylims = ylims_glider + [ylim_hakai]
    max_depth = max(yl[0] for yl in ylims)
    height_ratios = [yl[0] / max_depth for yl in ylims]

    n_glider = len(along_values)
    nrows = n_glider + 1  # +1 for Hakai

    # ---- make figure & a 2-column grid (plots | full-height colorbar)
    fig = plt.figure(figsize=(6, 8))
    gs = fig.add_gridspec(
        nrows=nrows, ncols=2,
        height_ratios=height_ratios,
        width_ratios=[1, 0.06],   # main column | colorbar column
        hspace=0.02,              # vertical spacing between panels
        wspace=0.02               # horizontal spacing between plots and cbar
    )

    # `axes` is always a list here, so no single-panel special case is needed.
    axes = [fig.add_subplot(gs[i, 0]) for i in range(nrows)]  # all data panels in col 0
    cax = fig.add_subplot(gs[:, 1])                           # full-length colorbar in col 1

    # ─── Draw GLIDER PANELS ──────────────────────────────────────────────────────
    last_pc = None
    for ax_i, (ax, along_value, ylim, (cube_sel, times)) in enumerate(
            zip(axes[:n_glider], along_values, ylims_glider, glider_panels)):
        temp = cube_sel['potential_temperature'].values
        density = cube_sel['potential_density'].values - 1000
        depth = cube_sel['depth'].values

        profiles = [
            (t, te, de)
            for t, te, de in zip(times, temp, density)
            if in_window(t) and not np.all(np.isnan(te))
        ]
        if not profiles:
            ax.text(0.5, 0.5, f"No data for along={along_value/1000:.0f} km",
                    ha='center', va='center', transform=ax.transAxes)
            # Apply ylim only: calling invert_yaxis() afterwards would flip a
            # (deep, shallow) limit back to surface-at-bottom.
            ax.set_ylim(ylim)
            continue

        # grids: depth x time, one column per distinct date. Profiles are written
        # in transect order, so for transects sharing a date the later one wins.
        this_times = sorted({t for t, _, _ in profiles})
        column_of = {t: col for col, t in enumerate(this_times)}
        temp_grid = np.full((profiles[0][1].shape[0], len(this_times)), np.nan)
        density_grid = np.full_like(temp_grid, np.nan)
        for t, te, de in profiles:
            temp_grid[:, column_of[t]] = te
            density_grid[:, column_of[t]] = de

        last_pc = ax.pcolormesh(this_times, depth, temp_grid,
                                cmap=cm.cm.thermal, shading='nearest',
                                vmin=vmin, vmax=vmax, rasterized=True)

        # density contours (thin, small labels)
        draw_isopycnals(ax, this_times, depth, density_grid)

        # label: distance from sill
        km_from_sill = along_value / 1000
        if abs(km_from_sill) < 0.25:
            label = "0 km (sill)"
        else:
            label = f"{int(round(km_from_sill))} km"

        ax.text(0.01, 0.95, label, transform=ax.transAxes,
                fontsize=9, fontweight="bold", va="top", ha="left")

        # axis styling to match
        ax.set_ylim(ylim)
        ax.xaxis.set_major_formatter(DateFormatter('%b %d\n%Y'))
        ax.axhline(130, color='k', linestyle='--', linewidth=0.8)

        # y-ticks only on the last glider subplot
        if ax_i != n_glider - 1:
            ax.set_yticklabels([])
            ax.set_ylabel("")
        else:
            ax.set_ylabel("Depth (m)")
        ax.tick_params(axis='x', which='both', bottom=False, labelbottom=False)

        # top ticks at actual sample times on the FIRST panel only (match style)
        if ax_i == 0:
            ax_top = ax.secondary_xaxis('top')
            ax_top.set_xticks(this_times)
            # t.day instead of the glibc-only '%-d' directive
            ax_top.set_xticklabels([f"{t:%b} {t.day}" for t in this_times],
                                   rotation=90, fontsize=7)
            ax_top.tick_params(axis='x', direction='out', length=2, width=0.5, pad=1)

    # ─── HAKAI PANEL (styled similarly) ──────────────────────────────────────────
    hakai_ax = axes[-1]
    regular_depth = np.arange(0, 500, 1)

    temp_profiles = []
    dens_profiles = []
    hakai_times = []

    # Skip the groupby when no rows survived filtering, so the "No Hakai casts"
    # message below is shown instead of grouping an empty dataset.
    casts = ds_station.groupby('time') if n_hakai_rows > 0 else []
    for t, group in casts:
        if group.sizes['row'] < 2:
            continue

        d = group['depth'].values.astype(float)
        SP = group['salinity'].values.astype(float)
        T = group['temperature'].values.astype(float)

        # Position for the TEOS-10 conversions: cast mean when available,
        # otherwise the same default the two-station sibling function uses.
        lat_use = lon_use = np.nan
        if 'latitude' in group and 'longitude' in group:
            lat_use = float(np.nanmean(group['latitude'].values))
            lon_use = float(np.nanmean(group['longitude'].values))
        if not (np.isfinite(lat_use) and np.isfinite(lon_use)):
            lat_use, lon_use = 51.0, -128.0

        # gsw.sigma0 expects Absolute Salinity and Conservative Temperature,
        # so convert first instead of passing the measured values straight in.
        p = gsw.p_from_z(-d, lat_use)
        SA = gsw.SA_from_SP(SP, p, lon_use, lat_use)
        CT = gsw.CT_from_t(SA, T, p)
        sigma0 = gsw.sigma0(SA, CT)

        # np.interp requires increasing sample points, so sort by depth first.
        order = np.argsort(d, kind='stable')
        d_sorted = d[order]
        if len(np.unique(d_sorted)) < 2:
            continue

        temp_profiles.append(np.interp(regular_depth, d_sorted, T[order],
                                       left=np.nan, right=np.nan))
        dens_profiles.append(np.interp(regular_depth, d_sorted, sigma0[order],
                                       left=np.nan, right=np.nan))

        t_dt = pd.Timestamp(t).to_pydatetime()
        hakai_times.append(t_dt)
        all_times.append(t_dt)

    if hakai_times:
        temp = np.array(temp_profiles).T  # depth x time
        dens = np.array(dens_profiles).T

        last_pc = hakai_ax.pcolormesh(hakai_times, regular_depth, temp,
                                      cmap=cm.cm.thermal, shading='nearest',
                                      vmin=vmin, vmax=vmax, rasterized=True)
        draw_isopycnals(hakai_ax, hakai_times, regular_depth, dens)
    else:
        hakai_ax.text(0.5, 0.5, f"No Hakai casts for {hakai_station}",
                      ha='center', va='center', transform=hakai_ax.transAxes)

    hakai_ax.text(0.01, 0.9, f"{hakai_station} (southern sill)",
                  transform=hakai_ax.transAxes, fontsize=9, fontweight="bold",
                  va="top", ha="left")
    hakai_ax.set_ylim(ylim_hakai)

    # remove depth labels from the Hakai panel
    hakai_ax.set_yticklabels([])
    hakai_ax.set_ylabel("")

    hakai_ax.tick_params(axis='x', which='both', bottom=False, labelbottom=False)
    hakai_ax.set_xlabel("")

    # ─── Shared x-limits ─────────────────────────────────────────────────────────
    if all_times:
        lower = min(all_times)
        # Hakai casts are filtered by year only, so clamp the left edge to the
        # same start date that is applied to the glider transects.
        if start_floor is not None:
            lower = max(start_floor, lower)
        global_xlim = (lower, max(all_times))
        for panel in axes:
            panel.set_xlim(global_xlim)

    # Slim vertical colorbar, similar spacing
    if last_pc is not None:
        cbar = fig.colorbar(last_pc, cax=cax, orientation='vertical')
        cbar.set_label("Temperature (°C)", fontsize=12, labelpad=4)
        cbar.ax.tick_params(labelsize=8)
    else:
        cax.axis('off')  # nothing was drawn, so hide the empty colorbar slot

    # Date ticks under the Hakai panel, one per cast
    if hakai_times:
        ax_bottom = hakai_ax.secondary_xaxis('bottom')
        ax_bottom.set_xticks(hakai_times)
        ax_bottom.set_xticklabels([f"{t:%b} {t.day}" for t in hakai_times],
                                  rotation=90, fontsize=7)
        ax_bottom.tick_params(axis='x', direction='out', length=2, width=0.5, pad=1)

    # fig.text(0.5, 0.02, "Date", ha="center", va="center", fontsize=12)

    fig.savefig(os.path.expanduser(save_path), format='pdf', bbox_inches='tight', dpi=300)

# Three glider panels (along = -40000, 0, 18000) stacked above the QCS07 cast panel.
# NOTE(review): `ds` is only visible as a commented-out `xr.open_dataset(...)` line in
# the notebook's import cell; confirm it is defined before running this cell, otherwise
# the call raises NameError.
# NOTE(review): save_path is left at its default, which is the same PDF path that
# plot_temp_time_series_subplots writes to -- the earlier figure will be overwritten.
plot_combined_sill_time_series(
    cube=new_cube,
    ds_hakai=ds,
    along_values=[-40000, 0, 18000],
    hakai_station="QCS07",
    ylims_glider=[(200,0), (140,0), (400,0)],
    ylim_hakai=(130, 0), 
    target_years=[2024, 2025],
    start_month=3
)

In [None]:
import xarray as xr
import numpy as np
import pandas as pd

import xarray as xr
import numpy as np
import pandas as pd

def patch_station_with_donor(
    target_ds, donor_ds,
    target_name="FZH08",
    min_depth=300,
    when="latest",
    tol_hours=36,
    donor_name="HKP04"
):
    """
    Append one donor cast (re-labeled as target_name) to target_ds for plotting.
    Keeps 'station' as a variable (not coordinate) to avoid concat conflicts.

    Parameters
    ----------
    target_ds : xarray.Dataset
        Row-indexed dataset that receives the extra cast.
    donor_ds : xarray.Dataset
        Row-indexed dataset with ``time``, ``depth``, ``temperature`` and
        ``salinity``; casts are the groups of rows sharing a ``time`` value.
    target_name : str
        Station label written onto every row of the borrowed cast.
    min_depth : float
        The chosen cast's maximum ``depth`` must be at least this value.
    when : "latest" or anything accepted by ``pd.Timestamp``
        ``"latest"`` picks the most recent cast that is deep enough. Otherwise
        the cast nearest in time to ``when`` is taken, and it must lie within
        ``tol_hours`` and be deep enough.
    tol_hours : float
        Time tolerance in hours; only used when ``when`` is not ``"latest"``.
    donor_name : str
        Name of the donor station, used only in the confirmation message.

    Returns
    -------
    xarray.Dataset
        ``target_ds`` with the donor cast concatenated along ``row``, or
        ``target_ds`` unchanged (after printing the reason) when no donor cast
        qualifies.
    """
    # Only complete rows take part in choosing a cast.
    complete = donor_ds.dropna(dim="row", subset=["time", "depth", "temperature", "salinity"])

    # Check for emptiness before grouping, so an empty donor reaches the message
    # below rather than being passed to groupby.
    casts = list(complete.groupby("time")) if complete.sizes.get("row", 0) else []
    if not casts:
        print("No valid donor casts found.")
        return target_ds

    def cast_max_depth(g):
        return float(np.nanmax(g["depth"].values)) if g.sizes.get("row", 0) else np.nan

    # Pick donor time
    if when == "latest":
        deep_enough = [(t, g) for (t, g) in casts if cast_max_depth(g) >= min_depth]
        if not deep_enough:
            print(f"No donor casts reach >= {min_depth} m.")
            return target_ds
        t_sel, g_sel = max(deep_enough, key=lambda tg: pd.Timestamp(tg[0]))
    else:
        when_ts = pd.Timestamp(when)
        tol = pd.Timedelta(hours=tol_hours)
        # `casts` is non-empty here, so min() always has a candidate.
        t_sel, g_sel = min(casts, key=lambda tg: abs(pd.Timestamp(tg[0]) - when_ts))
        if abs(pd.Timestamp(t_sel) - when_ts) > tol:
            print(f"No donor cast within ±{tol_hours}h of {when_ts}.")
            return target_ds
        if cast_max_depth(g_sel) < min_depth:
            print(f"Nearest donor cast is shallower than {min_depth} m.")
            return target_ds

    # The cast is taken from the unfiltered donor, so rows with missing values
    # at the selected time are carried over as well.
    donor_cast = donor_ds.where(donor_ds["time"] == np.datetime64(pd.Timestamp(t_sel)), drop=True)

    # Overwrite station *as a variable*, not a coordinate
    donor_cast = donor_cast.drop_vars("station", errors="ignore")
    relabeled = np.full(donor_cast.sizes["row"], target_name, dtype=object)
    donor_cast = donor_cast.assign(station=("row", relabeled))

    # Ensure target also has station as a variable
    if "station" in target_ds.coords:
        target_ds = target_ds.reset_coords("station")

    patched = xr.concat([target_ds, donor_cast], dim="row", join="outer", combine_attrs="drop_conflicts")

    print(f"Patched {target_name} with donor cast from {donor_name} at {pd.Timestamp(t_sel)} "
          f"(max depth ~{cast_max_depth(g_sel):.0f} m).")
    return patched

# Create a patched FZH08 that includes one deep HKP04 cast
# (when="latest" selects the most recent HKP04 cast whose maximum depth is >= min_depth).
FZH08_patched = patch_station_with_donor(FZH08, HKP04, target_name="FZH08", min_depth=300, when="latest")

# Station name -> dataset lookup handed to the plotting call below.
hakai_dict = {
    "QCS07": QCS07,
    "FZH08": FZH08_patched,  # <-- patched version
    "HKP04": HKP04,
}

# Now plot normally
# NOTE(review): plot_combined_sill_time_series_two_hakai is defined in a later cell of
# this notebook; that cell has to be run first, otherwise this call raises NameError.
plot_combined_sill_time_series_two_hakai(
    cube=new_cube,
    hakai_dict=hakai_dict,
    along_values=[-40000, 0],
    hakai_stations=["QCS07", "FZH08"],   # FZH08 panel will include the borrowed HKP04 cast
    ylims_glider=[(200,0), (140,0)],
    ylims_hakai=[(135,0), (400,0)],
    target_years=[2024, 2025],
    start_month=3
)

# Okay, so the above works for patching in the HKP04 data, but be careful: there is only one such cast and it is for the wrong date. Waiting to hear back from Drew.


In [None]:
def plot_combined_sill_time_series_two_hakai(
    cube,
    hakai_dict=None,
    along_values=[0, -18000],                 # glider panels (e.g., sill, basin)
    hakai_stations=("QCS07", "FZH08"),        # two Hakai stations
    ylims_glider=[(130, 0), (400, 0)],        # y-lims for glider panels
    ylims_hakai=[(125, 0), (125, 0)],         # one per Hakai station
    target_years=[2024, 2025],
    start_month=1,
    vmin=5.3,
    vmax=10,
    save_path="~/Desktop/along_time_series_dual.pdf",
    station_ll=None                            # optional: {'QCS07': (lat, lon), 'FZH08': (lat, lon)}
):
    """
    Like your original, but with two Hakai panels stacked under the glider panels.
    Uses TEOS-10 to compute σ0 so isopycnals match glider σθ.
    """

    # ─── Build ylims list (glider + 2×Hakai) and height ratios ───────────────────
    ylims = list(ylims_glider) + list(ylims_hakai)
    max_depths = [yl[0] for yl in ylims]
    max_depth = max(max_depths)
    height_ratios = [d / max_depth for d in max_depths]

    n_glider = len(along_values)
    n_hakai  = len(hakai_stations)
    nrows    = n_glider + n_hakai

    import matplotlib.gridspec as gridspec
    from matplotlib.dates import DateFormatter
    import os
    from datetime import datetime

    fig = plt.figure(figsize=(6, 9))
    gs = fig.add_gridspec(
        nrows=nrows, ncols=2,
        height_ratios=height_ratios,
        width_ratios=[1, 0.06],
        hspace=0.02,
        wspace=0.02
    )

    axes = [fig.add_subplot(gs[i, 0]) for i in range(nrows)]
    cax   = fig.add_subplot(gs[:, 1])
    if nrows == 1:
        axes = [axes]

    all_times = []

    # ─── GLIDER PANELS (unchanged style) ─────────────────────────────────────────
    last_pc = None
    for ax_i, (ax, along_value, ylim) in enumerate(zip(axes[:n_glider], along_values, ylims_glider)):
        cube_sel = cube.sel(along=along_value)
        temp = cube_sel['potential_temperature'].values
        density = cube_sel['potential_density'].values - 1000
        depth = cube_sel['depth'].values
        times = [datetime.strptime(t[:8], "%Y%m%d") for t in cube_sel.transect.values]

        filtered = [
            (t, te, de)
            for t, te, de in zip(times, temp, density)
            if (
                t.year in target_years
                and (start_month is None or t >= datetime(target_years[0], start_month, 1))
                and not np.all(np.isnan(te))
            )
        ]
        if not filtered:
            ax.text(0.5, 0.5, f"No data for along={along_value/1000:.0f} km",
                    ha='center', va='center', transform=ax.transAxes)
            ax.set_ylim(ylim); ax.invert_yaxis()
            continue

        times_data, temp_data, density_data = zip(*sorted(filtered))
        this_times = sorted(set(times_data))

        temp_grid = np.full((temp_data[0].shape[0], len(this_times)), np.nan)
        density_grid = np.full_like(temp_grid, np.nan)
        for i, t in enumerate(times_data):
            j = this_times.index(t)
            temp_grid[:, j] = temp_data[i]
            density_grid[:, j] = density_data[i]

        pc = ax.pcolormesh(this_times, depth, temp_grid,
                           cmap=cm.cm.thermal, shading='nearest',
                           vmin=vmin, vmax=vmax, rasterized=True)
        last_pc = pc

        for levels, color, lw in [
                (np.linspace(24, 27, 7), 'black', .31),
                ([25.6], 'white', 0.3),
                ([25.7], 'lime', 0.3),
                ([25.8], 'red', 0.3),
                ([25.9], 'blue', 0.3),
                ([26.0], 'black', 0.31),
                ([26.1], 'purple', 0.3),
                ([26.2], 'salmon', 0.3),
                ([26.3], 'yellow', 0.3),
                ([26.4], 'cyan', 0.3)]:
            cf_iso = ax.contour(this_times, depth, density_grid, levels=levels, colors=color, linewidths=lw)
            if lw != 0.3:
                ax.clabel(cf_iso, fmt='%1.2f', fontsize=8, inline=False)

        km_from_sill = (along_value) / 1000
        label = "0 km (sill)" if abs(km_from_sill) < 0.25 else f"{int(round(km_from_sill))} km"
        ax.text(0.01, 0.95, label, transform=ax.transAxes,
                fontsize=9, fontweight="bold", va="top", ha="left")

        ax.invert_yaxis()
        ax.set_ylim(ylim)
        ax.xaxis.set_major_formatter(DateFormatter('%b %d\n%Y'))
        ax.axhline(130, color='k', linestyle='--', linewidth=0.8)

        if ax_i != n_glider - 1:
            ax.set_yticklabels([])
            ax.set_ylabel("")
        else:
            ax.set_ylabel("Depth (m)")
        ax.tick_params(axis='x', which='both', bottom=False, labelbottom=False)

        if ax_i == 0:
            ax_top = ax.secondary_xaxis('top')
            ax_top.set_xticks(this_times)
            ax_top.set_xticklabels([t.strftime('%b %-d') for t in this_times], rotation=90, fontsize=7)
            ax_top.tick_params(axis='x', direction='out', length=2, width=0.5, pad=1)

        all_times.extend(this_times)

    # ─── HAKAI PANELS (two stations; TEOS-10 σ0 to match glider σθ) ──────────────
    for ax, station, ylim in zip(axes[n_glider:], hakai_stations, ylims_hakai):
        ds_station = hakai_dict[station]
        ds_station = ds_station.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time'])
        ds_station = ds_station.where(ds_station['time'].dt.year.isin(target_years), drop=True)

        regular_depth = np.arange(0, 500, 1)
        grouped = ds_station.groupby('time')

        temp_profiles = []
        dens_profiles = []
        hakai_times = []

        for t, group in grouped:
            if group.sizes['row'] < 2:
                continue

            # vectors
            d  = group['depth'].values.astype(float)          # m (positive down)
            SP = group['salinity'].values.astype(float)       # Practical Salinity
            T  = group['temperature'].values.astype(float)    # in-situ °C

            # lat/lon for TEOS-10 conversions
            if 'latitude' in group and 'longitude' in group:
                lat_use = float(np.nanmean(group['latitude'].values))
                lon_use = float(np.nanmean(group['longitude'].values))
            elif station_ll and station in station_ll:
                lat_use, lon_use = station_ll[station]
            else:
                lat_use, lon_use = 51.0, -128.0  # sensible default for QCS/FHS

            p  = gsw.p_from_z(-d, lat_use)
            SA = gsw.SA_from_SP(SP, p, lon_use, lat_use)
            CT = gsw.CT_from_t(SA, T, p)
            sigma0 = gsw.sigma0(SA, CT)  # kg/m^3 minus 1000

            # interpolate to regular depth grid
            order = np.argsort(d)
            d_sorted = d[order]
            T_sorted = T[order]
            sig_sorted = sigma0[order]

            if np.all(np.isnan(T_sorted)) or len(np.unique(d_sorted)) < 2:
                continue

            temp_profiles.append(np.interp(regular_depth, d_sorted, T_sorted, left=np.nan, right=np.nan))
            dens_profiles.append(np.interp(regular_depth, d_sorted, sig_sorted, left=np.nan, right=np.nan))

            t_dt = pd.Timestamp(t).to_pydatetime()
            hakai_times.append(t_dt)
            all_times.append(t_dt)

        if len(hakai_times) > 0:
            temp = np.array(temp_profiles).T  # depth x time
            dens = np.array(dens_profiles).T

            pc = ax.pcolormesh(hakai_times, regular_depth, temp,
                               cmap=cm.cm.thermal, shading='nearest',
                               vmin=vmin, vmax=vmax, rasterized=True)
            last_pc = pc

            # same contour sets as glider
            for levels, color, lw in [
                    (np.linspace(24, 27, 7), 'black', .31),
                    ([25.6], 'white', 0.3),
                    ([25.7], 'lime', 0.3),
                    ([25.8], 'red', 0.3),
                    ([25.9], 'blue', 0.3),
                    ([26.0], 'black', 0.31),
                    ([26.1], 'purple', 0.3),
                    ([26.2], 'salmon', 0.3),
                    ([26.3], 'yellow', 0.3),
                    ([26.4], 'cyan', 0.3)]:
                cf = ax.contour(hakai_times, regular_depth, dens, levels=levels, colors=color, linewidths=lw)
                if lw != 0.3:
                    ax.clabel(cf, fmt='%1.2f', fontsize=8, inline=False)
        else:
            ax.text(0.5, 0.5, f"No Hakai casts for {station}",
                    ha='center', va='center', transform=ax.transAxes)

        ax.text(0.01, 0.9, f"{station} (Hakai)",
                transform=ax.transAxes, fontsize=9, fontweight="bold",
                va="top", ha="left")

        ax.invert_yaxis()
        ax.set_ylim(ylim)
        ax.set_ylabel("Depth (m)")
        ax.axhline(130, color='k', linestyle='--', linewidth=0.8)
        ax.tick_params(axis='x', which='both',
               bottom=False, top=False,
               labelbottom=False, labeltop=False)

        if len(hakai_times) > 0:
            ax_top = ax.secondary_xaxis('bottom')
            ax_top.set_xticks(hakai_times)
            ax_top.set_xticklabels([t.strftime('%b %-d') for t in hakai_times],
                                rotation=90, fontsize=7)
            ax_top.tick_params(axis='x', direction='out', length=2, width=0.5, pad=1)

        # ─── Global x-limits (like your original) ────────────────────────────────────
    if all_times:
        march_start = datetime(target_years[0], 3, 1)
        global_xlim = (max(march_start, min(all_times)), max(all_times))
        for ax in axes:
            ax.set_xlim(global_xlim)

    # ─── Tidy rcparams and colorbar (unchanged) ──────────────────────────────────
    plt.rcParams.update({
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "axes.titlesize": 12,
        "axes.labelsize": 12,
    })

    if last_pc is not None:
        cbar = fig.colorbar(last_pc, cax=cax, orientation='vertical')
        cbar.set_label("Temperature (°C)", fontsize=12, labelpad=4)
        cbar.ax.tick_params(labelsize=8)

    fig.savefig(os.path.expanduser(save_path), format='pdf', bbox_inches='tight', dpi=300)

# Station name -> dataset, using the per-station files opened near the top of the file.
hakai_dict = {
    "QCS07": QCS07,
    "FZH08": FZH08,
    "HKP04": HKP04,
}

# Draw the combined glider + Hakai figure.
# NOTE(review): presumably one glider panel per entry of along_values (the panel
# at 0 is the one labelled "0 km (sill)") and one Hakai panel per entry of
# hakai_stations; the function header is defined above this cell — confirm there.
plot_combined_sill_time_series_two_hakai(
    cube=new_cube,
    hakai_dict=hakai_dict,
    along_values=[-40000, 0],
    hakai_stations=["QCS07", "FZH08"],  
    ylims_glider=[(200,0), (140,0)],
    ylims_hakai=[(135,0), (400,0)],
    target_years=[2024, 2025],
    start_month=3
)

# plot_combined_sill_time_series_two_hakai(
#     cube=new_cube,
#     hakai_dict=hakai_dict,
#     along_values=[-40000, 0],
#     hakai_stations=["QCS07", "HKP04"],  
#     ylims_glider=[(200,0), (140,0)],
#     ylims_hakai=[(135,0), (400,0)],
#     target_years=[2024, 2025],
#     start_month=3
# )

# TS diagram for 0-30km with insets #

In [None]:
def plot_ts_shelf_basin_by_transect(
    cube,
    region_range=(0, 30000),
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    xlim=(30.5, 34),
    ylim=(5, 16),
    ncols=4,
    inset_months=[4, 5, 6, 7, 8, 9],
    inset_limits_by_date=None,
):
    """Plot a grid of T–S diagrams, one panel per transect, coloured by along distance.

    Parameters
    ----------
    cube
        Dataset exposing ``sel``/``isel`` with an ``along`` coordinate, a
        ``transect`` coordinate whose labels start with ``YYYYMMDD``, and
        ``temperature`` / ``salinity`` variables.
    region_range
        ``(start, stop)`` forwarded to ``cube.sel(along=slice(start, stop))``;
        the same two values bound the colour scale.
    target_years, target_months
        A transect is plotted only when its date falls in both collections.
    xlim, ylim
        Salinity and temperature limits applied to every main panel.
    ncols
        Number of panel columns; rows are added as needed.
    inset_months
        Months whose panels receive a zoomed inset in the upper-right quadrant.
    inset_limits_by_date
        Optional ``{'YYYY-MM-DD': {'xlim': (..), 'ylim': (..)}}`` overriding the
        default inset window (33.3–33.8 in salinity, 6.0–7.1 in temperature)
        for individual dates.

    Returns
    -------
    None
        The figure is created as a side effect.  When no transect matches the
        year/month filter a message is printed and nothing is drawn.
    """
    import numpy as np
    import pandas as pd

    # Select the transects first: an empty selection must return before any
    # global plotting state is touched (plt.subplots rejects a zero-row grid).
    subset = cube.sel(along=slice(*region_range)).load()
    times = pd.to_datetime([str(t)[:8] for t in subset.transect.values])
    selected_idxs = [i for i, t in enumerate(times)
                     if t.year in target_years and t.month in target_months]
    if not selected_idxs:
        print("No transects match the requested years/months.")
        return

    import gsw
    import matplotlib.pyplot as plt
    from matplotlib.colors import Normalize
    from matplotlib.patches import Rectangle
    from scipy.optimize import root_scalar

    plt.rcParams.update({
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "axes.titlesize": 14,
        "axes.labelsize": 12,
        "legend.fontsize": 10,
        "figure.titlesize": 16})

    n_panels = len(selected_idxs)
    nrows = int(np.ceil(n_panels / ncols))
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                             figsize=(4.5 * ncols, 4 * nrows), squeeze=False)
    fig.subplots_adjust(wspace=0.1, hspace=0.1)
    axes = axes.ravel()

    # Background isopycnal field on a regular salinity/temperature grid.
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 300), np.linspace(5.3, 20, 300))
    sigma = gsw.sigma0(S_grid, T_grid)

    # Anchor every contour label on the line T = T_label by solving
    # sigma0(S, T_label) = level for S.
    T_label = 6.0
    isopycnal_levels = [23.0, 23.5, 24.0, 24.5, 25.0, 25.5, 25.6, 25.7, 25.8,
                        25.9, 26.0, 26.1, 26.2, 26.3, 26.4, 26.5]
    manual_labels = {}
    for iso in isopycnal_levels:
        sol = root_scalar(lambda S, iso=iso: gsw.sigma0(S, T_label) - iso,
                          bracket=[28, 35], method='brentq')
        if sol.converged:
            manual_labels[iso] = (sol.root, T_label)

    # (levels, colour, linewidth) for each contour set; shared by panels and insets.
    contour_styles = [
        (np.linspace(23, 27, 9), 'black', 0.5),
        ([25.6], 'white', 0.5), ([25.7], 'lime', 0.5),
        ([25.8], 'red', 0.5), ([25.9], 'blue', 0.5),
        ([26.0], 'black', 0.5), ([26.1], 'purple', 0.5),
        ([26.2], 'salmon', 0.5), ([26.3], 'yellow', 0.5),
        ([26.4], 'cyan', 0.5)]

    cmap = plt.colormaps['jet']
    norm = Normalize(vmin=region_range[0], vmax=region_range[1])

    def draw_isopycnals(target_ax, fontsize):
        """Draw the dashed isopycnals and label those that have an anchor point."""
        for levels, color, lw in contour_styles:
            cs = target_ax.contour(S_grid, T_grid, sigma, levels=levels,
                                   colors=color, linewidths=lw, linestyles='--')
            for level in levels:
                if level in manual_labels:
                    target_ax.clabel(cs, fmt='%1.2f', fontsize=fontsize, inline=False,
                                     manual=[manual_labels[level]])

    def scatter_columns(target_ax, temp, sali, along_vals, size):
        """Scatter each along column; iterating in reverse draws the first column last (on top)."""
        for a_idx, along in reversed(list(enumerate(along_vals))):
            T_col = temp.isel(along=a_idx).values
            S_col = sali.isel(along=a_idx).values
            mask = ~np.isnan(T_col) & ~np.isnan(S_col)
            if np.any(mask):
                target_ax.scatter(S_col[mask], T_col[mask], s=size,
                                  color=cmap(norm(along)), marker='o')

    for i, idx in enumerate(selected_idxs):
        ax = axes[i]
        ax.set_facecolor('grey')
        draw_isopycnals(ax, fontsize=7)

        ds = subset.isel(transect=idx)
        temp = ds['temperature']
        sali = ds['salinity']
        along_vals = ds['along'].values
        scatter_columns(ax, temp, sali, along_vals, size=1.25)

        ax.set_title(times[idx].strftime('%d %b %Y'))
        # Only the left column and the last row keep their axis labels/ticks.
        if i % ncols == 0:
            ax.set_ylabel("$\\theta$ (\u00b0C)")
        else:
            ax.set_yticklabels([])
        if i // ncols == nrows - 1:
            ax.set_xlabel("Salinity (psu)")
        else:
            ax.set_xticklabels([])
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

        # ─── Add Inset ─────────────────────────────
        if times[idx].month in inset_months:
            inset = ax.inset_axes([0.5, 0.5, 0.48, 0.48])
            inset.set_facecolor('grey')
            inset_xlim_local = (33.3, 33.8)
            inset_ylim_local = (6.0, 7.1)

            date_str = times[idx].strftime('%Y-%m-%d')
            if inset_limits_by_date and date_str in inset_limits_by_date:
                overrides = inset_limits_by_date[date_str]
                inset_xlim_local = overrides.get('xlim', inset_xlim_local)
                inset_ylim_local = overrides.get('ylim', inset_ylim_local)

            draw_isopycnals(inset, fontsize=6)
            scatter_columns(inset, temp, sali, along_vals, size=0.5)

            inset.set_xlim(*inset_xlim_local)
            inset.set_ylim(*inset_ylim_local)
            inset.set_xticks([])
            inset.set_yticks([])

            # Outline, on the main panel, the window that the inset magnifies.
            box = Rectangle(
                (inset_xlim_local[0], inset_ylim_local[0]),
                inset_xlim_local[1] - inset_xlim_local[0],
                inset_ylim_local[1] - inset_ylim_local[0],
                edgecolor='black', facecolor='none', linestyle='--', linewidth=1)
            ax.add_patch(box)

    # Drop the unused cells of a partially filled last row.
    for j in range(n_panels, len(axes)):
        fig.delaxes(axes[j])

    fig.subplots_adjust(right=0.88, top=0.92)
    cbar_ax = fig.add_axes([0.9, 0.1, 0.02, 0.8])
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, cax=cbar_ax)
    cbar.set_label('Along-transect distance (m)', fontsize=12)
    fig.suptitle("T–S Diagrams Colored by Along Distance", fontsize=24)

    for ax in fig.axes:
        ax.set_rasterization_zorder(1)

# T–S grid for every 2024–2025 transect in new_cube.
# region_range is given in new_cube's `along` coordinate, which was offset by
# 25500 and sign-flipped when the cube was loaded at the top of the file.
plot_ts_shelf_basin_by_transect(
    new_cube,
    region_range=(-25500, 4500),
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    xlim=(30.5, 34),
    ylim=(5, 16),
    ncols=4,
    # All twelve months are listed, so every panel gets an inset.
    inset_months=[1,2,3,4, 5, 6, 7, 8, 9,10,11,12],
    # Per-date inset windows, currently disabled: with this argument omitted
    # every inset uses the function's built-in default window.
    # inset_limits_by_date={
    #     '2024-03-12': {'xlim': (32.5, 33), 'ylim': (7.5, 8.4)},
    #     '2024-04-16': {'xlim': (32.75, 33.5), 'ylim': (7.3, 8.5)},
    #     '2024-05-10': {'xlim': (32.7, 33.5), 'ylim': (7.2, 8.1)},
    #     '2024-05-16': {'xlim': (32.9, 33.6), 'ylim': (6.8, 7.9)},
    #     '2024-07-17': {'xlim': (33.3, 33.8), 'ylim': (6, 7.1)},
    #     '2024-07-23': {'xlim': (33.3, 33.8), 'ylim': (6, 7.1)},
    #     '2024-11-14': {'xlim': (33, 33.5), 'ylim': (6.6, 7.5)},
    #     '2024-11-19': {'xlim': (33, 33.5), 'ylim': (6.8, 8)},
    #     '2025-01-09': {'xlim': (32.5, 33.25), 'ylim': (6.8, 8.5)},
    #     '2025-01-14': {'xlim': (32.5, 33.25), 'ylim': (7, 8.5)},
    #     '2025-03-12': {'xlim': (32.5, 33.25), 'ylim': (7.3, 8.25)},
    #     '2025-04-30': {'xlim': (33, 33.8), 'ylim': (6.5, 7.5)},
    # }
)

# Density - Oxygen subplots #

In [None]:
def plot_o2_density_by_transect(
    cube,
    region_range=(0, 30000),
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    xlim=(0, 350),
    ylim=(27, 23),  # High density at bottom
    ncols=4,
):
    """Plot oxygen against density anomaly, one panel per transect.

    The cube is loaded, cut to ``region_range`` along ``along``, and filtered
    to transects whose date (first eight characters of the transect label)
    falls in ``target_years`` and ``target_months``.  Transects with no sample
    having both ``oxygen_concentration`` and ``potential_density`` are dropped.
    Points are coloured by along distance; the colour scale spans
    ``region_range``.  Prints a message and returns ``None`` when nothing is
    left to plot; otherwise the figure is created as a side effect.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    from matplotlib.colors import Normalize

    plt.rcParams.update({
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "axes.titlesize": 14,
        "axes.labelsize": 12,
        "legend.fontsize": 10,
        "figure.titlesize": 16})

    cube = cube.load()
    subset = cube.sel(along=slice(*region_range))
    times = pd.to_datetime([str(label)[:8] for label in subset.transect.values])

    def oxygen_and_sigma(transect_idx):
        """Return (section, oxygen array, density minus 1000) for one transect."""
        section = subset.isel(transect=transect_idx)
        oxygen = section['oxygen_concentration'].values
        sigma0 = section['potential_density'].values - 1000
        return section, oxygen, sigma0

    def has_paired_samples(transect_idx):
        """True when at least one sample has both oxygen and density."""
        _, oxygen, sigma0 = oxygen_and_sigma(transect_idx)
        return np.any(~(np.isnan(oxygen) | np.isnan(sigma0)))

    # ─── Keep only transects in the date window that carry usable data ─────
    in_window = [pos for pos, stamp in enumerate(times)
                 if stamp.year in target_years and stamp.month in target_months]
    valid_idxs = [pos for pos in in_window if has_paired_samples(pos)]

    if len(valid_idxs) == 0:
        print("No valid transects with data.")
        return

    n_panels = len(valid_idxs)
    nrows = int(np.ceil(n_panels / ncols))
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                             figsize=(4.5 * ncols, 4 * nrows), squeeze=False)
    axes = axes.ravel()
    fig.subplots_adjust(wspace=0.1, hspace=0.1)

    cmap = plt.colormaps['jet']
    norm = Normalize(vmin=region_range[0], vmax=region_range[1])

    for panel_no, (ax, transect_idx) in enumerate(zip(axes, valid_idxs)):
        ax.set_facecolor('grey')

        section, oxygen, sigma0 = oxygen_and_sigma(transect_idx)
        along_vals = section['along'].values

        # Walk the columns from last to first so the first column is drawn on top.
        for col in range(len(along_vals) - 1, -1, -1):
            oxygen_col = oxygen[:, col]
            sigma_col = sigma0[:, col]
            keep = ~(np.isnan(oxygen_col) | np.isnan(sigma_col))
            if keep.any():
                ax.scatter(oxygen_col[keep], sigma_col[keep], s=0.25,
                           color=cmap(norm(along_vals[col])), marker='o')

        ax.set_title(times[transect_idx].strftime('%d %b %Y'))

        # Axis labels only on the left column and the last row.
        row, column = divmod(panel_no, ncols)
        if column:
            ax.set_yticklabels([])
        else:
            ax.set_ylabel("σ₀ (kg/m³)")

        if row != nrows - 1:
            ax.set_xticklabels([])
        else:
            ax.set_xlabel("Oxygen (µmol/L)")

        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.grid(color='black', linestyle='--', linewidth=0.5, alpha=0.7)

    # ─── Remove unused axes ─────────────────────────────
    for spare in axes[n_panels:]:
        fig.delaxes(spare)

    fig.subplots_adjust(right=0.88, top=0.92)
    cbar_ax = fig.add_axes([0.9, 0.1, 0.02, 0.8])
    mappable = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    mappable.set_array([])
    fig.colorbar(mappable, cax=cbar_ax).set_label('Along-transect distance (m)', fontsize=12)

    fig.suptitle("O₂–σ₀", fontsize=24)

    for each_ax in fig.axes:
        each_ax.set_rasterization_zorder(1)

# O2–density grid, three panels per row.
# NOTE(review): `cube` is not assigned in the part of the file visible here (the
# opening cell only defines `new_cube`; its `cube = ...` load is commented out).
# Confirm `cube` is defined before this cell runs, or pass `new_cube` if that is
# the dataset the (-25500, 4500) range refers to.
plot_o2_density_by_transect(
    cube,
    region_range=(-25500, 4500),
    target_years=[2024, 2025],
    target_months=list(range(1, 13)),
    xlim=(0, 350),
    ylim=(27, 23),
    ncols=3
)

# Histogram side by side #

In [None]:
def plot_ts_and_o2_density_histograms_side_by_side(
    cube,
    region_range=(0, 30000),
    target_years=[2020, 2021, 2022, 2023, 2024, 2025],
    sal_bins=np.arange(28, 34.01, 0.01),
    temp_bins=np.arange(5, 16.01, 0.01),
    dens_bins=np.arange(23, 27.01, 0.01),
    o2_bins=np.arange(0, 400.01, 1),
):
    """Plot 2-D histograms of T–S and of O2–density side by side.

    Parameters
    ----------
    cube
        Dataset exposing ``load``/``sel``/``isel`` with an ``along`` coordinate,
        a ``transect`` coordinate whose labels start with ``YYYYMMDD``, and the
        variables ``temperature``, ``salinity``, ``potential_density`` and
        ``oxygen_concentration``.
    region_range
        ``(start, stop)`` forwarded to ``cube.sel(along=slice(start, stop))``.
    target_years
        Years whose transects are pooled into the histograms.
    sal_bins, temp_bins, dens_bins, o2_bins
        Bin edges; their first and last entries also set the axis limits.

    Returns
    -------
    None
        The figure is created as a side effect and the point totals and
        maximum bin counts are printed.  When no transect falls in
        ``target_years`` the function prints ``"Missing data."`` and returns
        before any plotting state is touched.
    """
    import numpy as np
    import pandas as pd

    cube = cube.load()
    subset = cube.sel(along=slice(*region_range))
    times = pd.to_datetime([str(t)[:8] for t in subset.transect.values])
    selected_idxs = [i for i, t in enumerate(times) if t.year in target_years]

    # Pool the valid samples of every selected transect.
    all_T, all_S, all_sigma, all_o2 = [], [], [], []
    for idx in selected_idxs:
        ds = subset.isel(transect=idx)
        temp = ds['temperature'].values
        sal = ds['salinity'].values
        sigma0 = ds['potential_density'].values - 1000
        o2 = ds['oxygen_concentration'].values

        ts_mask = ~np.isnan(temp) & ~np.isnan(sal)
        o2_mask = ~np.isnan(o2) & ~np.isnan(sigma0)

        all_T.append(temp[ts_mask])
        all_S.append(sal[ts_mask])
        all_sigma.append(sigma0[o2_mask])
        all_o2.append(o2[o2_mask])

    if not all_T or not all_o2:
        print("Missing data.")
        return

    import gsw
    import matplotlib.pyplot as plt
    from matplotlib.cm import ScalarMappable
    from matplotlib.colors import LogNorm
    from scipy.optimize import root_scalar

    plt.rcParams.update({
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "axes.titlesize": 16,
        "axes.labelsize": 14,
        "legend.fontsize": 12,
        "figure.titlesize": 18})

    T = np.concatenate(all_T)
    S = np.concatenate(all_S)
    SIGMA = np.concatenate(all_sigma)
    O2 = np.concatenate(all_o2)

    print(f"Total valid T–S points: {len(T)}")
    print(f"Total valid O₂–σ₀ points: {len(O2)}")

    hist_ts, xedges_ts, yedges_ts = np.histogram2d(S, T, bins=[sal_bins, temp_bins])
    hist_o2, xedges_o2, yedges_o2 = np.histogram2d(O2, SIGMA, bins=[o2_bins, dens_bins])

    # One shared colour scale so the two panels are directly comparable.
    log_norm = LogNorm(vmin=1, vmax=7500)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6),
                                   gridspec_kw={'width_ratios': [1, 1], 'wspace': 0.15})
    ax1.set_facecolor('grey')
    ax2.set_facecolor('grey')

    # histogram2d returns (x, y)-indexed counts; pcolormesh wants (y, x).
    ax1.pcolormesh(xedges_ts, yedges_ts, hist_ts.T,
                   cmap='inferno', norm=log_norm, rasterized=True)
    ax2.pcolormesh(xedges_o2, yedges_o2, hist_o2.T,
                   cmap='inferno', norm=log_norm, rasterized=True)

    # ─── Isopycnals on T–S plot ─────────────────────
    S_grid, T_grid = np.meshgrid(np.linspace(28, 34, 300),
                                 np.linspace(5.3, 20, 300))
    sigma_grid = gsw.sigma0(S_grid, T_grid)
    isopycnal_levels = np.arange(20, 27, 0.5)

    cs1 = ax1.contour(S_grid, T_grid, sigma_grid, levels=isopycnal_levels,
                      colors='black', linewidths=0.5, linestyles='--', alpha=0.7)

    # ─── Label isopycnals at the bottom ─────────────────────
    # Solve sigma0(S, T_label) = level for S; levels whose root lies outside
    # the salinity range make brentq raise ValueError and are left unlabelled.
    T_label = temp_bins[0] + 0.1
    manual_labels = []
    for iso in isopycnal_levels:
        try:
            sol = root_scalar(lambda s, iso=iso: gsw.sigma0(s, T_label) - iso,
                              bracket=[sal_bins[0], sal_bins[-1]], method='brentq')
        except ValueError:
            continue
        if sol.converged:
            manual_labels.append((sol.root, T_label))

    ax1.clabel(cs1, fmt='%1.1f', fontsize=8, inline=False, manual=manual_labels)

    # ─── Vertical guide lines every 50 on the O2 axis ─────────────────────
    for o2_val in np.arange(50, o2_bins[-1], 50):
        ax2.axvline(o2_val, color='black', linestyle='--', linewidth=0.5, alpha=0.7)

    # ─── Horizontal lines on O2–density plot ─────────────────────
    # Only levels inside the plotted density range: the text is not clipped to
    # the axes, so a label for an off-axis level would land far outside them.
    dens_lo, dens_hi = sorted((dens_bins[0], dens_bins[-1]))
    for iso in isopycnal_levels:
        if not dens_lo <= iso <= dens_hi:
            continue
        ax2.axhline(iso, color='black', linestyle='--', linewidth=0.5, alpha=0.7)
        ax2.text(x=o2_bins[-1] + 10, y=iso, s=f'{iso:.1f}', color='white',
                 va='center', fontsize=8)

    ax1.set_xlabel(r"Salinity (psu)")
    ax1.set_ylabel(r"$\theta$ (°C)")
    ax2.set_xlabel(r"$O_2$ ($\mu$mol/L)")
    ax2.set_ylabel(r"$\sigma_\theta$ (kg m$^{-3}$)")

    ax1.set_xlim(sal_bins[0], sal_bins[-1])
    ax1.set_ylim(temp_bins[0], temp_bins[-1])
    ax2.set_xlim(o2_bins[0], o2_bins[-1])
    ax2.set_ylim(dens_bins[-1], dens_bins[0])  # reversed: densest water at the bottom

    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    sm = ScalarMappable(norm=log_norm, cmap='inferno')
    sm.set_array([])
    fig.colorbar(sm, cax=cbar_ax).set_label("Log-scaled Count")

    # Build the title from the arguments so it always describes what was plotted.
    km_lo = region_range[0] / 1000
    km_hi = region_range[1] / 1000
    years = [str(year) for year in target_years]
    year_text = years[0] if len(years) == 1 else f"{', '.join(years[:-1])} and {years[-1]}"
    fig.suptitle(
        rf"$\theta$–$S$ and $O_2$–$\sigma_\theta$ ({km_lo:g}–{km_hi:g} km, {year_text})",
        fontsize=20)

    print(f"Max T–S bin count: {hist_ts.max()}")
    print(f"Max O₂–σ₀ bin count: {hist_o2.max()}")

# Pooled 2024–2025 histograms; the salinity and oxygen bin ranges differ from
# the function defaults (salinity starts at 30, oxygen runs to 425).
# NOTE(review): `cube` is not assigned in the part of the file visible here (the
# opening cell only defines `new_cube`; its `cube = ...` load is commented out).
# Confirm `cube` is defined before this cell runs.
plot_ts_and_o2_density_histograms_side_by_side(
    cube,
    region_range=(0, 30000),
    target_years=[2024, 2025],
    sal_bins=np.arange(30, 34.01, 0.01),
    temp_bins=np.arange(5, 16.01, 0.01),
    dens_bins=np.arange(23, 27.01, 0.01),
    o2_bins=np.arange(0, 425.01, 1),
)

# Water at Bella Coola Analysis #

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Discharge CSV for the gauge plotted below (path is relative to the home directory)
file_path = "~/Desktop/08FB011_QRD_20250812T1746.csv"

# Read CSV, skipping the first 10 lines of metadata
df = pd.read_csv(file_path, skiprows=10)

# Rename columns for easier use.
# Assigning three names requires the file to have exactly three columns.
df.columns = ['Date', 'Parameter', 'Value']

# Convert Date to datetime
df['Date'] = pd.to_datetime(df['Date'])

# Coerce Value to numeric; anything that cannot be parsed becomes NaN
df['Value'] = pd.to_numeric(df['Value'], errors='coerce')

# Drop the rows whose Value is NaN (including those coerced above)
df = df.dropna(subset=['Value'])

# Mean of every remaining row in the file.
# NOTE(review): this is an annual mean only if the file covers exactly one
# year — confirm the record length before quoting it as such.
annual_mean = df['Value'].mean()
print(f"Annual mean discharge: {annual_mean:.2f} m³/s")

# Plot the time series with the mean drawn as a dashed red reference line
plt.figure(figsize=(12, 5))
plt.plot(df['Date'], df['Value'], label='Daily Mean Discharge')
plt.axhline(annual_mean, color='red', linestyle='--', label=f'Annual Mean ({annual_mean:.2f} m³/s)')
plt.xlabel('Date')
plt.ylabel('Discharge (m³/s)')
plt.title('Bella Coola River Above Hammer Creek - Daily Discharge')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
import xarray as xr
import os
# Open the gridded netCDF file from the Desktop (``~`` is expanded explicitly).
even_newer_cube = xr.open_dataset(os.path.expanduser('~/Desktop/dfo-hal1002-20250701_grid_delayed.nc'))
# Bare expression at the end of the cell: the notebook displays the dataset summary.
even_newer_cube