In [None]:

import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import warnings
warnings.filterwarnings("ignore") # For those annoying runtime warnings
import cartopy.feature as cfeature
import numpy as np
import os
import cmocean as cm

# --- Hakai CTD dataset -------------------------------------------------------
# "~" is expanded so the path resolves inside the current user's home directory.
file_path = os.path.expanduser('~/Desktop/Summer 2025 Python/Hakai_calvert.nc')
ds = xr.open_dataset(file_path)

# Station table: only the three station columns, with incomplete rows removed.
df_stations = (
    ds[['station', 'latitude', 'longitude']]
    .to_dataframe()
    .dropna()
)

# Keep one row per distinct (station, latitude, longitude) combination.
df_stations_unique = df_stations.drop_duplicates(
    subset=['station', 'latitude', 'longitude']
)

# Flatten the table into (lat, lon, name) tuples for the map functions.
station_coords = [
    (lat, lon, name)
    for lat, lon, name in zip(
        df_stations_unique['latitude'],
        df_stations_unique['longitude'],
        df_stations_unique['station'],
    )
]

def plot_general_bathymetry_with_stations(
        topo_file= os.path.expanduser('~/Desktop/Summer 2025 Python/british_columbia_3_msl_2013.nc'),
        lon_bounds=(-134.8, -125.7), lat_bounds=(49.2, 52.0),
        title='Bathymetry Map with Stations', deepest=1500,
        transect_stations=None, station_table=None):
    """
    Draw a bathymetry map with fixed stations, a buoy and a station transect.

    Parameters:
    - topo_file: path to a NetCDF file with `lon`/`lat` coordinates and a
      `Band1` variable; `-Band1` is contoured as depth
    - lon_bounds: (min, max) longitude of the map and of the bathymetry subset
    - lat_bounds: (min, max) latitude of the map and of the bathymetry subset
    - title: axes title
    - deepest: largest contour level (22 levels are spread over 0..deepest)
    - transect_stations: optional ordered list of station names joined by the
      dashed transect line; defaults to the HKP04 ... BUR3 list below
    - station_table: optional DataFrame with `station`, `latitude` and
      `longitude` columns used to look up transect coordinates; defaults to
      the notebook-level `df_stations_unique`

    Returns nothing; the figure is left open on the pyplot state machine.
    Transect stations missing from the table are reported and skipped, and a
    station listed at several positions is drawn once (first occurrence).
    """
    # Station names and coordinates as (latitude, longitude)
    stations = {
        'FZH08': (51.7537, -127.9805),
        'HP01': (51.7413, -128.011),
        'HP02': (51.7295, -128.0428),
        'HP03': (51.721,  -128.068),
        'HP04': (51.7102, -128.0905),
        'HP05': (51.7053, -128.1192),
        'HP08': (51.7045, -128.1452),
        'HP09': (51.7067, -128.181),
        'HP10': (51.7062, -128.2152),
        'QCS01':(51.7117, -128.2767),
        'CL02': (51.705,  -128.368),
        'CL03': (51.464,  -128.5),
        'CL04': (51.4083, -128.6605),
        'CL05': (51.417,  -128.8237),
        'CL06': (51.3935, -128.8812),
        'CL07': (51.3713, -128.9367),
        'CL08': (51.348,  -128.9937),
        'CL09': (51.2515, -129.3543),
        'CL10': (51.1842, -129.5478),
        'CL11': (51.081,  -129.855),
        'CL12': (51.0167, -130.0),
        'CL13': (50.9187, -130.979),
        'CL14': (50.8493, -131.1492),
        'P16':  (49.2833, -134.6667),
        'HAK1': (51.71196,-128.23608 )
    }

    buoys = {'C46204': (51.380, -128.770)}

    if transect_stations is None:
        transect_stations = ["HKP04", "FZH04", "BUR1", "FC1", "FC2", "FC3",
                             "DE1", "DE2", "DE3", "DE4", "DE5", "DE6",
                             "BUR8", "BUR7", "BUR6", "BUR5", "BUR4", "BUR3"]
    if station_table is None:
        # Fall back to the table built when the Hakai dataset is loaded.
        station_table = df_stations_unique

    def _in_bounds(lat, lon):
        """True when (lat, lon) lies inside the requested map window."""
        return (lat_bounds[0] <= lat <= lat_bounds[1]
                and lon_bounds[0] <= lon <= lon_bounds[1])

    # Load and subset bathymetry. The arrays are pulled into memory inside the
    # context manager so the file handle is released before plotting.
    with xr.open_dataset(topo_file) as topo_full:
        topo = topo_full.sel(
            lon=slice(lon_bounds[0], lon_bounds[1]),
            lat=slice(lat_bounds[0], lat_bounds[1]))
        topo_lon = topo['lon'].values
        topo_lat = topo['lat'].values
        depth = -topo['Band1'].values

    # Set up figure
    fig, ax = plt.subplots(figsize=(12, 9), subplot_kw={'projection': ccrs.PlateCarree()})
    ax.set_extent([*lon_bounds, *lat_bounds], crs=ccrs.PlateCarree())

    # Gridlines
    gl = ax.gridlines(draw_labels=True, linestyle='--', alpha=0.3)
    gl.top_labels = False
    gl.right_labels = False

    # Bathymetry: the same levels drive the filled and the line contours.
    levels = np.linspace(0, deepest, 22)
    contourf = ax.contourf(topo_lon, topo_lat, depth,
                           levels=levels, cmap=cm.cm.deep, extend='both')

    contours = ax.contour(topo_lon, topo_lat, depth,
                          levels=levels,
                          colors='k', linewidths=0.3)
    ax.clabel(contours, fmt='%d', fontsize=3)

    plt.colorbar(contourf, ax=ax, label='Depth (m)')

    # Plot stations that fall within bounds
    for name, (lat, lon) in stations.items():
        if _in_bounds(lat, lon):
            ax.plot(lon, lat, marker='o', color='red', markersize=4, transform=ccrs.PlateCarree())
            ax.text(lon + 0.02, lat, name, fontsize=7, transform=ccrs.PlateCarree())

    # Plot buoys that fall within bounds
    for name, (lat, lon) in buoys.items():
        if _in_bounds(lat, lon):
            ax.plot(lon, lat, marker='o', color='blue', markersize=8, transform=ccrs.PlateCarree())
            ax.text(lon + 0.02, lat, name, fontsize=7, transform=ccrs.PlateCarree())

    # ─── Plot Custom Transect ───────────────────────────────────
    # One position per station (first occurrence) so a station recorded at
    # several slightly different positions is not drawn and labelled twice.
    transect_lookup = (station_table
                       .drop_duplicates(subset='station')
                       .set_index('station'))
    present = [name for name in transect_stations if name in transect_lookup.index]
    missing = [name for name in transect_stations if name not in transect_lookup.index]
    if missing:
        # print rather than warnings.warn: the notebook silences all warnings.
        print(f"Transect stations not found in station table (skipped): {missing}")

    # Look up coordinates in the requested transect order.
    df_transect = transect_lookup.loc[present].reset_index()

    if not df_transect.empty:
        ax.plot(df_transect['longitude'], df_transect['latitude'],
                marker='o', markersize=5, color='red', linestyle='--',
                linewidth=1.5, transform=ccrs.PlateCarree(), zorder=3)

        for row in df_transect.itertuples(index=False):
            ax.text(row.longitude + 0.015, row.latitude,
                    row.station, fontsize=7, transform=ccrs.PlateCarree(), zorder=4)

    ax.set_title(title)
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
    # Stretch the y-axis by 1/cos(mean latitude) so distances look roughly isotropic.
    ax.set_aspect(1 / np.cos(np.deg2rad(np.mean(lat_bounds))))

# Interactive figures; the `widget` backend is provided by ipympl.
%matplotlib widget

# Regional map: overrides the default bounds and uses 700 as the deepest
# contour level (the colorbar is labelled in metres).
plot_general_bathymetry_with_stations(
    lon_bounds=(-128.5, -126.5),
    lat_bounds=(51, 52.5),
     deepest= 700)

In [None]:
# Re-open the Hakai dataset so this cell also works when run on its own.
file_path = os.path.expanduser('~/Desktop/Summer 2025 Python/Hakai_calvert.nc')
ds = xr.open_dataset(file_path)

# Drop any points missing lat/lon or station info
ds_stations = ds.dropna(dim='row', subset=['station', 'latitude', 'longitude'])

# Find unique stations: np.unique returns the names sorted, and `index` holds
# the position of the first row of each name, so every station contributes
# exactly one (lat, lon, name) tuple.
_, index = np.unique(ds_stations['station'].values, return_index=True)
station_coords = list(zip(
    ds_stations['latitude'].values[index],
    ds_stations['longitude'].values[index],
    ds_stations['station'].values[index]
))

In [None]:
# Display the distinct station names in the dataset (np.unique returns them sorted).
np.unique(ds['station'].values)

In [None]:
%matplotlib widget
def plot_bathymetry_with_stations(
    topo,
    station_coords,
    selected_station_names=None,
    lon_bounds=(-128.5, -127.5),
    lat_bounds=(51, 52.5),
    deepest=500
):
    """
    Draw a filled-contour bathymetry map and overlay labelled station markers.

    Parameters:
    - topo: xarray.Dataset indexed by `lon`/`lat` with a `Band1` variable
      (`-Band1` is plotted as depth)
    - station_coords: iterable of (lat, lon, name) tuples
    - selected_station_names: names to draw; None draws every station
    - lon_bounds: (min, max) longitude of the map window
    - lat_bounds: (min, max) latitude of the map window
    - deepest: largest contour level (22 levels over 0..deepest)

    Stations outside the map window are not drawn. Returns nothing.
    """
    import matplotlib.pyplot as plt
    import cartopy.crs as ccrs
    import cmocean as cm
    import numpy as np

    plate = ccrs.PlateCarree()

    # Bathymetry restricted to the requested window
    region = topo.sel(
        lon=slice(lon_bounds[0], lon_bounds[1]),
        lat=slice(lat_bounds[0], lat_bounds[1])
    )
    grid_lon = region['lon']
    grid_lat = region['lat']
    depth = -region['Band1']

    fig, ax = plt.subplots(figsize=(12, 9), subplot_kw={'projection': plate})
    ax.set_extent(list(lon_bounds) + list(lat_bounds))

    # Filled contours with thin labelled isobaths on top
    contour_levels = np.linspace(0, deepest, 22)
    filled = ax.contourf(grid_lon, grid_lat, depth,
                         levels=contour_levels, cmap=cm.cm.deep, extend='both')
    isobaths = ax.contour(grid_lon, grid_lat, depth,
                          levels=contour_levels, colors='k', linewidths=0.3)
    ax.clabel(isobaths, fmt='%d', fontsize=4)
    plt.colorbar(filled, ax=ax, label='Depth (m)')

    # Station markers: skip anything deselected or outside the window
    for lat, lon, name in station_coords:
        if selected_station_names is not None and name not in selected_station_names:
            continue
        if not (lat_bounds[0] <= lat <= lat_bounds[1]):
            continue
        if not (lon_bounds[0] <= lon <= lon_bounds[1]):
            continue
        ax.plot(lon, lat, marker='o', color='red', markersize=6, transform=plate)
        ax.text(lon + 0.01, lat, str(name), fontsize=9, transform=plate)

    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")

    plt.tight_layout()

# Bathymetry grid; opened once here and passed as `topo` to the map functions.
topo_file= os.path.expanduser('~/Desktop/Summer 2025 Python/british_columbia_3_msl_2013.nc')
topo = xr.open_dataset(topo_file)

# # Example function call
# plot_bathymetry_with_stations(
#     topo=topo,
#     station_coords=station_coords,
#     selected_station_names=[
#     "DE1",
#     "DE2",
#     "DE3",
#     "DFO2",
#     "FC1",
#     "FC2",
#     "FZH01",
#     "FZH04",
#     "FZH08",
#     "FZH10",
#     "FZH13",
#     "FZH14",
#     "HKP01",
#     "HKP03",
#     "HKP04",
#     "HKP05",
#     "HKP06",
#     "KC10",
#     "QCS01",
#     "QCS02",
#     "QCS03",
#     "QCS04",
#     "QCS05",
#     "QCS06",
#     "QCS07",
#     "QCS08"],
#     lon_bounds=(-128.1, -127.8),
#     lat_bounds=(51.7, 51.8),
#     deepest=500)

# Map every station (no name filter) inside a 2° x 2° window.
plot_bathymetry_with_stations(
    topo=topo,
    station_coords=station_coords,
    lon_bounds=(-129, -127),
    lat_bounds=(51, 53),
    deepest=500,
)

In [None]:
# plot_station_timeseries("BUR1", ds, years = [2020, 2021, 2022, 2023, 2024, 2025], ylim = (500, 0))
# plot_station_timeseries("BUR2", ds, years = [2020, 2021, 2022, 2023, 2024, 2025], ylim = (500, 0))
# plot_station_timeseries("BUR3", ds, years = [2020, 2021, 2022, 2023, 2024, 2025], ylim = (500, 0))
# plot_station_timeseries("BUR4", ds, years = [2020, 2021, 2022, 2023, 2024, 2025], ylim = (500, 0))
# plot_station_timeseries("BUR5", ds, years = [2020, 2021, 2022, 2023, 2024, 2025], ylim = (500, 0))
# plot_station_timeseries("BUR6", ds, years = [2020, 2021, 2022, 2023, 2024, 2025], ylim = (500, 0))
# plot_station_timeseries("BUR7", ds, years = [2020, 2021, 2022, 2023, 2024, 2025], ylim = (500, 0))
# plot_station_timeseries("BUR8", ds, years = [2020, 2021, 2022, 2023, 2024, 2025], ylim = (500, 0))

# plot_station_timeseries("DE1", ds, years = [2020, 2021, 2022, 2023, 2024, 2025], ylim = (500, 0))
# plot_station_timeseries("DE2", ds, years = [2020, 2021, 2022, 2023, 2024, 2025], ylim = (500, 0))
# plot_station_timeseries("DE3", ds, years = [2020, 2021, 2022, 2023, 2024, 2025], ylim = (500, 0))
# plot_station_timeseries("DE4", ds, years = [2020, 2021, 2022, 2023, 2024, 2025], ylim = (500, 0))
# plot_station_timeseries("DE5", ds, years = [2020, 2021, 2022, 2023, 2024, 2025], ylim = (500, 0))
# plot_station_timeseries("DE6", ds, years = [2020, 2021, 2022, 2023, 2024, 2025], ylim = (500, 0))
# plot_station_timeseries("DE7", ds, years = [2020, 2021, 2022, 2023, 2024, 2025], ylim = (500, 0))

# NOTE: plot_station_timeseries is defined in a later cell of this notebook;
# that cell must have been run before this one.
for _station in ("BUR1", "BUR2", "BUR3", "BUR4", "BUR5", "BUR6", "BUR7", "BUR8"):
    plot_station_timeseries(_station, ds, years=[2023, 2024, 2025], ylim=(500, 0))

In [None]:
# NOTE: plot_station_timeseries is defined in a later cell of this notebook.
plot_station_timeseries(
    "DE1",
    ds,
    years=[2020, 2021, 2022, 2023, 2024, 2025],
    ylim=(500, 0),
)

In [None]:
# NOTE: plot_station_timeseries is defined in a later cell of this notebook.
for _station in ("FZH08", "HKP04"):
    plot_station_timeseries(_station, ds, years=[2023, 2024, 2025], ylim=(500, 0))

In [None]:
def plot_section_grid_by_month(ds, station_order, target_year=2023, depth_max=500, ncols=4):
    """
    Plot a 12-panel grid of monthly temperature sections with density contours.

    For every calendar month of `target_year`, each cast (rows sharing one
    `time` value) at the stations in `station_order` is interpolated onto a
    1-unit depth grid and drawn as one column of a distance-vs-depth section.
    Months with fewer than two profiles at distinct positions show "No data".

    Parameters:
    - ds: xarray.Dataset with a `row` dimension and `station`, `time`, `depth`,
      `temperature`, `salinity`, `latitude` and `longitude` variables
    - station_order: station names in the order they are walked along the
      transect
    - target_year: year to plot
    - depth_max: profiles are interpolated onto np.arange(0, depth_max, 1)
    - ncols: number of panel columns (rows = ceil(12 / ncols))

    Notes:
    - plt.rcParams is updated globally, so later figures inherit the font sizes.
    - The x coordinate is the cumulative geodesic distance (km) between
      consecutive profiles, in the order they are collected.
    - NOTE(review): gsw.sigma0 is documented to take Absolute Salinity and
      Conservative Temperature; the dataset's `salinity`/`temperature` are
      passed as stored -- confirm that this approximation is intended.

    Returns nothing; the figure is left open on the pyplot state machine.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import cmocean as cm
    import gsw
    import pandas as pd
    from pyproj import Geod

    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 20,
        'figure.titlesize': 20})

    geod = Geod(ellps="WGS84")
    regular_depth = np.arange(0, depth_max, 1)

    # Station and year filtering do not depend on the month, so scan the full
    # dataset once per station instead of once per station and month.
    required_vars = ['temperature', 'salinity', 'depth', 'time']
    year_data = {}
    for sid in station_order:
        if sid in year_data:
            continue
        subset = ds.where(ds['station'] == sid, drop=True)
        subset = subset.dropna(dim='row', subset=required_vars)
        if subset.sizes['row'] >= 2:
            subset = subset.where(subset['time'].dt.year == target_year, drop=True)
        # None marks stations that cannot contribute a profile this year.
        year_data[sid] = subset if subset.sizes['row'] >= 2 else None

    # (levels, colour, line width) for the density contours drawn on each panel.
    isopycnal_styles = [
        (np.linspace(24, 27, 7), 'black', 0.5),
        ([25.6], 'white', 0.5),
        ([25.7], 'lime', 0.5),
        ([25.8], 'red', 0.5),
        ([25.9], 'blue', 0.5),
        ([26.0], 'black', 0.5),
        ([26.1], 'purple', 0.5),
        ([26.2], 'salmon', 0.5),
        ([26.3], 'yellow', 0.5),
        ([26.4], 'cyan', 0.5)]

    nrows = int(np.ceil(12 / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 6.4, nrows * 4.8), sharex=True, sharey=True)
    axes = axes.flatten()

    pc = None  # last pcolormesh drawn; stays None when no month has data

    for month in range(1, 13):
        temp_profiles = []
        dens_profiles = []
        along_coords = []
        sample_days = []
        station_labels = []

        ref_lat = None
        ref_lon = None
        cumulative_dist = 0

        for sid in station_order:
            station_year = year_data[sid]
            if station_year is None:
                continue

            ds_station = station_year.where(
                station_year['time'].dt.month == month, drop=True)

            if ds_station.sizes['row'] < 2:
                continue

            for t, group in ds_station.groupby('time'):
                if group.sizes['row'] < 2:
                    continue

                # np.interp needs increasing sample depths; sort each cast.
                order = np.argsort(group['depth'].values, kind='stable')
                depth_vals = group['depth'].values[order]
                temp_vals = group['temperature'].values[order]
                sal_vals = group['salinity'].values[order]
                dens_vals = gsw.sigma0(sal_vals, temp_vals)

                # NaN outside the sampled depth range instead of extrapolating.
                temp_interp = np.interp(regular_depth, depth_vals, temp_vals, left=np.nan, right=np.nan)
                dens_interp = np.interp(regular_depth, depth_vals, dens_vals, left=np.nan, right=np.nan)

                lat = group['latitude'].values[0]
                lon = group['longitude'].values[0]

                if ref_lat is None:
                    dist_km = 0
                else:
                    # Distance from the previous profile, accumulated along the walk.
                    _, _, dist_m = geod.inv(ref_lon, ref_lat, lon, lat)
                    cumulative_dist += dist_m / 1000
                    dist_km = cumulative_dist
                ref_lat, ref_lon = lat, lon

                temp_profiles.append(temp_interp)
                dens_profiles.append(dens_interp)
                along_coords.append(dist_km)
                sample_days.append(pd.to_datetime(t).strftime('%b %d'))
                station_labels.append(sid)

        ax = axes[month - 1]
        if len(temp_profiles) < 2 or len(np.unique(along_coords)) < 2:
            ax.set_title(pd.Timestamp(target_year, month, 1).strftime('%B'), pad=10)
            ax.text(0.5, 0.5, 'No data', transform=ax.transAxes, ha='center', va='center')
            continue

        temp = np.array(temp_profiles).T
        dens = np.array(dens_profiles).T
        along = np.array(along_coords)

        pc = ax.pcolormesh(along, regular_depth, temp, shading='nearest',
                           cmap=cm.cm.thermal, vmin=5.3, vmax=10)

        for levels, color, lw in isopycnal_styles:
            iso = ax.contour(along, regular_depth, dens, levels=levels, colors=color, linewidths=lw)
            ax.clabel(iso, fmt='%1.2f', fontsize=6)

        for x, label, day in zip(along, station_labels, sample_days):
            ax.text(x, -10, f"{label}\n{day}", ha='center', va='bottom', fontsize=6, rotation=90)

        ax.set_title(pd.Timestamp(target_year, month, 1).strftime('%B'))

    # The panels share one y-axis, so invert it exactly once (depth increases
    # downward); the check keeps this idempotent.
    if not axes[0].yaxis_inverted():
        axes[0].invert_yaxis()

    # Hide unused subplots
    for j in range(12, len(axes)):
        fig.delaxes(axes[j])

    fig.suptitle(f"Monthly Temperature Sections — {target_year}", fontsize=28)
    fig.text(0.5, 0.01, 'Distance along transect (km)', ha='center', fontsize=18)
    fig.text(0.01, 0.5, 'Depth (m)', va='center', rotation='vertical', fontsize=18)
    if pc is not None:
        # Every panel uses the same vmin/vmax, so any mesh can feed the colorbar.
        cbar_ax = fig.add_axes([0.92, 0.15, 0.015, 0.7])
        fig.colorbar(pc, cax=cbar_ax, label='Temperature (°C)')
    plt.tight_layout(rect=[0.03, 0.04, 0.9, 0.94])

# Station IDs in the order they are walked along the transect; intended as the
# `station_order` argument of the section-plotting functions.
station_order = ["FZH08", "BUR1", "FC1", "FC2", "FC3", 
                 "DE1", "DE2", "DE3", "DE4", "DE5", "DE6",
                 "BUR8", "BUR7", "BUR6", "BUR5", "BUR4", "BUR3"]

In [None]:
def plot_oxygen_section_grid_by_month(ds, station_order, target_year=2023, depth_max=500, ncols=3, oxymax=160):
    """
    Plot a 12-panel grid of monthly dissolved-oxygen sections.

    For every calendar month of `target_year`, each cast (rows sharing one
    `time` value) at the stations in `station_order` is interpolated onto a
    1-unit depth grid and drawn as one column of a distance-vs-depth section.
    Oxygen is contoured every 20 units in black, with the 60 contour in red.
    Months with fewer than two profiles at distinct positions show "No data".

    Parameters:
    - ds: xarray.Dataset with a `row` dimension and `station`, `time`, `depth`,
      `temperature`, `salinity`, `dissolved_oxygen_ml_l`, `latitude` and
      `longitude` variables (rows missing temperature or salinity are dropped
      even though only oxygen is plotted)
    - station_order: station names in the order they are walked along the
      transect
    - target_year: year to plot
    - depth_max: profiles are interpolated onto np.arange(0, depth_max, 1)
    - ncols: number of panel columns (rows = ceil(12 / ncols))
    - oxymax: upper end of the colour scale and of the contour levels; the
      colour norm is centred on 60, so oxymax must be greater than 60

    Notes:
    - plt.rcParams is updated globally, so later figures inherit the font sizes.
    - NOTE(review): the factor 44.661 converts mL/L to µmol/L; labelling the
      result µmol/kg assumes a density of about 1 kg/L -- confirm.

    Returns nothing; the figure is left open on the pyplot state machine.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import cmocean as cm
    import pandas as pd
    from pyproj import Geod
    from matplotlib.colors import TwoSlopeNorm

    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 16,
        'axes.labelsize': 16,
        'xtick.labelsize': 10,
        'ytick.labelsize': 12,
        'legend.fontsize': 16,
        'figure.titlesize': 22})

    geod = Geod(ellps="WGS84")
    regular_depth = np.arange(0, depth_max, 1)

    # Station and year filtering do not depend on the month, so scan the full
    # dataset once per station instead of once per station and month.
    required_vars = ['temperature', 'salinity', 'depth', 'time', 'dissolved_oxygen_ml_l']
    year_data = {}
    for sid in station_order:
        if sid in year_data:
            continue
        subset = ds.where(ds['station'] == sid, drop=True)
        subset = subset.dropna(dim='row', subset=required_vars)
        if subset.sizes['row'] >= 2:
            subset = subset.where(subset['time'].dt.year == target_year, drop=True)
        # None marks stations that cannot contribute a profile this year.
        year_data[sid] = subset if subset.sizes['row'] >= 2 else None

    # Black contours every 20 except 60 (red)
    base_levels = np.arange(0, oxymax + 1, 20)
    other_levels = [lvl for lvl in base_levels if not np.isclose(lvl, 60)]

    nrows = int(np.ceil(12 / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 6.4, nrows * 4.8), sharex=True, sharey=True)
    axes = axes.flatten()

    pc = None  # last pcolormesh drawn; stays None when no month has data

    for month in range(1, 13):
        oxy_profiles = []
        along_coords = []
        sample_days = []
        station_labels = []

        ref_lat = None
        ref_lon = None
        cumulative_dist = 0

        for sid in station_order:
            station_year = year_data[sid]
            if station_year is None:
                continue

            ds_station = station_year.where(
                station_year['time'].dt.month == month, drop=True)

            if ds_station.sizes['row'] < 2:
                continue

            for t, group in ds_station.groupby('time'):
                if group.sizes['row'] < 2:
                    continue

                # np.interp needs increasing sample depths; sort each cast.
                order = np.argsort(group['depth'].values, kind='stable')
                depth_vals = group['depth'].values[order]
                oxy_vals = group['dissolved_oxygen_ml_l'].values[order] * 44.661  # ml/L to umol/kg

                lat = group['latitude'].values[0]
                lon = group['longitude'].values[0]

                if ref_lat is None:
                    dist_km = 0
                else:
                    # Distance from the previous profile, accumulated along the walk.
                    _, _, dist_m = geod.inv(ref_lon, ref_lat, lon, lat)
                    cumulative_dist += dist_m / 1000
                    dist_km = cumulative_dist
                ref_lat, ref_lon = lat, lon

                # NaN outside the sampled depth range instead of extrapolating.
                oxy_interp = np.interp(regular_depth, depth_vals, oxy_vals, left=np.nan, right=np.nan)
                oxy_profiles.append(oxy_interp)
                along_coords.append(dist_km)
                sample_days.append(pd.to_datetime(t).strftime('%b %d'))
                station_labels.append(sid)

        ax = axes[month - 1]
        if len(oxy_profiles) < 2 or len(np.unique(along_coords)) < 2:
            ax.set_title(pd.Timestamp(target_year, month, 1).strftime('%B'), pad=10)
            ax.text(0.5, 0.5, 'No data', transform=ax.transAxes, ha='center', va='center')
            continue

        oxy = np.array(oxy_profiles).T
        along = np.array(along_coords)

        # Diverging colour scale centred on 60.
        norm = TwoSlopeNorm(vmin=0, vcenter=60, vmax=oxymax)
        pc = ax.pcolormesh(along, regular_depth, oxy, shading='nearest',
                           cmap=cm.cm.balance_r, norm=norm)

        if other_levels:
            cs_black = ax.contour(along, regular_depth, oxy, levels=other_levels,
                                  colors='black', linewidths=0.5)
            ax.clabel(cs_black, fmt='%1.0f', fontsize=7)
        cs_red = ax.contour(along, regular_depth, oxy, levels=[60], colors='red', linewidths=1.5)
        ax.clabel(cs_red, fmt='%1.0f', fontsize=8)

        for x, label, day in zip(along, station_labels, sample_days):
            ax.text(x, -10, f"{label}\n{day}", ha='center', va='bottom', fontsize=6, rotation=90)

        ax.set_title(pd.Timestamp(target_year, month, 1).strftime('%B'), pad=10)

    # The panels share one y-axis and invert_yaxis() toggles, so calling it per
    # panel flips the axis back for every second month with data. Invert once,
    # after the loop, and only if it is not already inverted.
    if not axes[0].yaxis_inverted():
        axes[0].invert_yaxis()

    for j in range(12, len(axes)):
        fig.delaxes(axes[j])

    fig.suptitle(f"Monthly Oxygen Sections — {target_year}", fontsize=28)
    fig.text(0.5, 0.01, 'Distance along transect (km)', ha='center', fontsize=18)
    fig.text(0.01, 0.5, 'Depth (m)', va='center', rotation='vertical', fontsize=18)
    if pc is not None:
        # Every panel uses the same norm limits, so any mesh can feed the colorbar.
        cbar_ax = fig.add_axes([0.92, 0.15, 0.015, 0.7])
        fig.colorbar(pc, cax=cbar_ax, label='Oxygen (μmol/kg)')
    plt.tight_layout(rect=[0.03, 0.04, 0.9, 0.94])

In [None]:

# plot_station_timeseries("DE1", ds, years = [2023, 2024, 2025], ylim = (500, 0))
# plot_station_timeseries("DE2", ds, years = [2023, 2024, 2025], ylim = (500, 0))
# plot_station_timeseries("DE3", ds, years = [2023, 2024, 2025], ylim = (500, 0))
# plot_station_timeseries("DE4", ds, years = [2023, 2024, 2025], ylim = (500, 0))
# plot_station_timeseries("DE5", ds, years = [2023, 2024, 2025], ylim = (500, 0))
# plot_station_timeseries("DE6", ds, years = [2023, 2024, 2025], ylim = (500, 0))
# plot_station_timeseries("DE7", ds, years = [2023, 2024, 2025], ylim = (500, 0))

In [None]:
def plot_section_by_stations(ds, station_order, target_year=2024, target_month=5, depth_max=400):
    """
    Plot one temperature section with density contours for a single month.

    Each cast (rows sharing one `time` value) taken in the target month at the
    stations in `station_order` is interpolated onto a 1-unit depth grid and
    drawn as one column of a distance-vs-depth section.

    Parameters:
    - ds: xarray.Dataset with a `row` dimension and `station`, `time`, `depth`,
      `temperature`, `salinity`, `latitude` and `longitude` variables
    - station_order: station names in the order they are walked along the
      transect
    - target_year, target_month: month to plot
    - depth_max: profiles are interpolated onto np.arange(0, depth_max, 1)

    Notes:
    - plt.rcParams is updated globally, so later figures inherit the font sizes.
    - A message is printed and nothing is drawn when there are no profiles, or
      fewer than two profiles at distinct positions (contouring needs at least
      a 2 x 2 grid).
    - NOTE(review): gsw.sigma0 is documented to take Absolute Salinity and
      Conservative Temperature; the dataset's `salinity`/`temperature` are
      passed as stored -- confirm that this approximation is intended.

    Returns nothing; the figure is left open on the pyplot state machine.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import cmocean as cm
    import gsw
    import pandas as pd
    from pyproj import Geod

    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 18,
        'xtick.labelsize': 12,
        'ytick.labelsize': 15,
        'legend.fontsize': 20,
        'figure.titlesize': 22})

    geod = Geod(ellps="WGS84")
    regular_depth = np.arange(0, depth_max, 1)

    temp_profiles = []
    dens_profiles = []
    along_coords = []
    sample_days = []
    station_labels = []

    ref_lat = None
    ref_lon = None
    cumulative_dist = 0

    for sid in station_order:
        ds_station = ds.where(ds['station'] == sid, drop=True)
        ds_station = ds_station.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time'])
        ds_station = ds_station.where(
            (ds_station['time'].dt.year == target_year) &
            (ds_station['time'].dt.month == target_month), drop=True)

        if ds_station.sizes['row'] < 2:
            continue

        for t, group in ds_station.groupby('time'):
            if group.sizes['row'] < 2:
                continue

            # np.interp needs increasing sample depths; sort each cast.
            order = np.argsort(group['depth'].values, kind='stable')
            depth_vals = group['depth'].values[order]
            temp_vals = group['temperature'].values[order]
            sal_vals = group['salinity'].values[order]
            dens_vals = gsw.sigma0(sal_vals, temp_vals)

            # NaN outside the sampled depth range instead of extrapolating.
            temp_interp = np.interp(regular_depth, depth_vals, temp_vals, left=np.nan, right=np.nan)
            dens_interp = np.interp(regular_depth, depth_vals, dens_vals, left=np.nan, right=np.nan)

            lat = group['latitude'].values[0]
            lon = group['longitude'].values[0]

            if ref_lat is None:
                dist_km = 0
            else:
                # Distance from the previous profile, accumulated along the walk.
                _, _, dist_m = geod.inv(ref_lon, ref_lat, lon, lat)
                cumulative_dist += dist_m / 1000
                dist_km = cumulative_dist
            ref_lat, ref_lon = lat, lon

            temp_profiles.append(temp_interp)
            dens_profiles.append(dens_interp)
            along_coords.append(dist_km)
            sample_days.append(pd.to_datetime(t).strftime('%b %d'))
            station_labels.append(sid)

    if not temp_profiles:
        print("No valid profiles found.")
        return

    # Same guard as the monthly grid functions: a single column cannot be contoured.
    if len(temp_profiles) < 2 or len(np.unique(along_coords)) < 2:
        print("At least two profiles at distinct positions are needed to draw a section.")
        return

    temp = np.array(temp_profiles).T
    dens = np.array(dens_profiles).T
    along = np.array(along_coords)

    # Plotting
    fig, ax = plt.subplots(figsize=(14, 6))
    pc = ax.pcolormesh(along, regular_depth, temp, shading='nearest',
                       cmap=cm.cm.thermal, vmin=5.3, vmax=10)

    # (levels, colour, line width) for the density contours.
    for levels, color, lw in [
        (np.linspace(24, 27, 7), 'black', 0.5),
        ([25.6], 'white', 0.5),
        ([25.7], 'lime', 0.5),
        ([25.8], 'red', 0.5),
        ([25.9], 'blue', 0.5),
        ([26.0], 'black', 0.5),
        ([26.1], 'purple', 0.5),
        ([26.2], 'salmon', 0.5),
        ([26.3], 'yellow', 0.5),
        ([26.4], 'cyan', 0.5)]:
        iso = ax.contour(along, regular_depth, dens, levels=levels, colors=color, linewidths=lw)
        ax.clabel(iso, fmt='%1.2f', fontsize=6)

    # Station name and sampling day above each profile column.
    for x, label, day in zip(along, station_labels, sample_days):
        ax.text(x, -10, f"{label}\n{day}", ha='center', va='bottom', fontsize=10, rotation=90)

    ax.invert_yaxis()  # depth increases downward
    ax.set_xlabel("Along-Transect Distance (km)")
    ax.set_ylabel("Depth (m)")
    ax.set_title(f"Temperature Section — {pd.Timestamp(target_year, target_month, 1).strftime('%B %Y')}")
    cbar = fig.colorbar(pc, ax=ax)
    cbar.set_label("Temperature (°C)")

# Stations in the order they are walked along the section.
station_order = [
    "FZH08", "BUR1",
    "FC1", "FC2", "FC3",
    "DE1", "DE2", "DE3", "DE4", "DE5", "DE6",
    "BUR8", "BUR7", "BUR6", "BUR5", "BUR4", "BUR3",
]

# Temperature section for May 2024 on a 0-499 depth grid.
plot_section_by_stations(
    ds,
    station_order,
    target_year=2024,
    target_month=5,
    depth_max=500,
)

In [None]:
# NOTE: plot_oxygen_timeseries must already be defined in the kernel when this
# cell runs. Same settings for every station, BUR stations first, then DE.
for _station in (
    "BUR1", "BUR2", "BUR3", "BUR4", "BUR5", "BUR6", "BUR7", "BUR8",
    "DE1", "DE2", "DE3", "DE4", "DE5", "DE6", "DE7",
):
    plot_oxygen_timeseries(
        _station,
        ds,
        years=[2020, 2021, 2022, 2023, 2024, 2025],
        ylim=(500, 0),
        show_density_contours=False,
    )

In [None]:
# BUR3 oxygen and temperature time series on the same 320-0 depth window.
# NOTE: both plotting functions must already be defined in the kernel.
plot_oxygen_timeseries(
    "BUR3",
    ds,
    years=[2020, 2021, 2022, 2023, 2024, 2025],
    ylim=(320, 0),
    show_density_contours=False,
)
plot_station_timeseries(
    "BUR3",
    ds,
    years=[2020, 2021, 2022, 2023, 2024, 2025],
    ylim=(320, 0),
)

In [None]:
def plot_along_path(
    cube,
    topo,
    transect_name,
    lon_bounds=(-128.5, -127.5),
    lat_bounds=(51, 52.5),
    deepest=500,
    marker_alongs=(3000, 4000, 5000, 7000, 9500, 23500)
):
    """
    Map one transect over bathymetry, colouring the track by along distance.

    Parameters:
    - cube: xarray.Dataset with a `transect` coordinate and `longitude`,
      `latitude` and `along` variables (`along` is divided by 1000 and shown
      as km)
    - topo: xarray.Dataset indexed by `lon`/`lat` with a `Band1` variable
      (`-Band1` is plotted as depth)
    - transect_name: transect label to select; its first 8 characters are
      parsed as a YYYYMMDD date for the title
    - lon_bounds: (min, max) longitude of the map window
    - lat_bounds: (min, max) latitude of the map window
    - deepest: largest bathymetry contour level
    - marker_alongs: along distances (meters) to flag with a red marker; each
      is placed on the track point whose along distance is closest

    Returns nothing; the figure is left open on the pyplot state machine.
    """
    import matplotlib.pyplot as plt
    import cartopy.crs as ccrs
    import cmocean
    import numpy as np
    import pandas as pd

    plate = ccrs.PlateCarree()

    # Bathymetry restricted to the map window
    bathy = topo.sel(
        lon=slice(lon_bounds[0], lon_bounds[1]),
        lat=slice(lat_bounds[0], lat_bounds[1])
    )
    bathy_lon = bathy['lon']
    bathy_lat = bathy['lat']
    depth = -bathy['Band1']

    # Track of the requested transect
    track = cube.sel(transect=transect_name)
    track_lon = track['longitude'].values
    track_lat = track['latitude'].values
    along_km = track['along'].values / 1000  # km

    fig, ax = plt.subplots(figsize=(12, 9), subplot_kw={'projection': plate})
    ax.set_extent(list(lon_bounds) + list(lat_bounds))

    # Filled bathymetry with thin labelled isobaths
    contour_levels = np.linspace(0, deepest, 22)
    filled = ax.contourf(bathy_lon, bathy_lat, depth,
                         levels=contour_levels, cmap=cmocean.cm.deep, extend='both')
    isobaths = ax.contour(bathy_lon, bathy_lat, depth,
                          levels=contour_levels, colors='k', linewidths=0.3)
    ax.clabel(isobaths, fmt='%d', fontsize=4)
    plt.colorbar(filled, ax=ax, label='Depth (m)')

    # Track points coloured by along distance
    path = ax.scatter(track_lon, track_lat, c=along_km, s=12, cmap='viridis_r',
                      transform=plate)

    # Annotate the first and last track points
    ax.text(track_lon[0]+0.0001, track_lat[0], f"Start: {along_km[0]:.1f} km",
            fontsize=9, transform=plate)
    ax.text(track_lon[-1]+0.0001, track_lat[-1], f"End: {along_km[-1]:.1f} km",
            fontsize=9, transform=plate)

    # Red markers at the track points nearest to each requested distance
    targets_km = [meters/1000 for meters in marker_alongs]
    for target_km in targets_km:
        nearest = (np.abs(along_km - target_km)).argmin()
        ax.plot(track_lon[nearest], track_lat[nearest],
                marker='o', color='red', markersize=8, transform=plate)
        ax.text(track_lon[nearest]+0.0001, track_lat[nearest],
                f"{target_km*1000:.0f} m",
                fontsize=9, color='red', transform=plate)

    # Second colorbar, for the along-distance colouring
    plt.colorbar(path, ax=ax, label="Along Distance (km)")

    # Title carries the date encoded at the start of the transect name
    transect_date = pd.to_datetime(transect_name[:8], format='%Y%m%d')
    date_str = transect_date.strftime('%B %d, %Y')
    ax.set_title(f"Glider Path - {date_str}")

    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")

    plt.tight_layout()

cube = xr.open_dataset(os.path.expanduser('~/Desktop/Summer 2025 Python/calvert_cube.nc'))

# Map two transects with identical map settings.
for _transect_name in ('20241119_out', '20241114_return'):
    plot_along_path(
        cube=cube,
        topo=topo,
        transect_name=_transect_name,
        lon_bounds=(-128.5, -127.5),
        lat_bounds=(51, 52.5),
        deepest=500,
    )

In [None]:
def plot_station_timeseries(station_id, ds, ylim=(500, 0), years=[2023, 2024, 2025]):
    """
    Plot a depth-time temperature section for one station with density contours.

    Rows of ``ds`` whose ``station`` equals ``station_id`` and whose ``time``
    falls in one of ``years`` are grouped by timestamp. Each group is treated
    as one vertical profile, sorted by depth and linearly interpolated onto a
    1 m grid covering 0-499 m (NaN outside the sampled depth range).
    Temperature is drawn with ``pcolormesh``; sigma0 from ``gsw.sigma0`` is
    overlaid as labelled contours.

    Parameters
    ----------
    station_id : str
        Value compared against ``ds['station']``.
    ds : xarray.Dataset
        Must provide ``station``, ``time``, ``depth``, ``temperature`` and
        ``salinity`` along a ``row`` dimension.
    ylim : tuple
        Passed unchanged to ``ax.set_ylim``; ``(deep, 0)`` puts the surface
        at the top of the panel.
    years : list of int
        Calendar years to keep. The default list is never mutated.

    Returns
    -------
    None
        A message is printed and the function returns early when no data
        (or no usable profile) is left after filtering.

    Notes
    -----
    ``plt.rcParams`` is updated globally, so the font sizes set here persist
    for figures created afterwards.
    """
    import matplotlib.pyplot as plt
    from matplotlib.dates import DateFormatter
    import numpy as np
    import cmocean as cm
    import gsw
    import pandas as pd

    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 20,
        'figure.titlesize': 20})

    # Filter to the desired station
    ds_station = ds.where(ds['station'] == station_id, drop=True)

    # Drop rows missing required variables
    ds_station = ds_station.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time'])

    # Filter by year
    ds_station = ds_station.where(ds_station['time'].dt.year.isin(years), drop=True)

    if ds_station.sizes['row'] < 2:
        print(f"No valid data for station {station_id} in years {years}.")
        return

    # Common 1 m depth grid every profile is interpolated onto
    regular_depth = np.arange(0, 500, 1)

    temp_profiles = []
    dens_profiles = []
    valid_times = []

    # One group per distinct timestamp == one profile
    for t, group in ds_station.groupby('time'):
        if group.sizes['row'] < 2:
            continue

        # np.interp requires increasing x-coordinates; sort each profile by
        # depth so samples stored out of order do not yield nonsense values.
        depth_vals = group['depth'].values
        order = np.argsort(depth_vals, kind='stable')
        depth_vals = depth_vals[order]
        temp_vals = group['temperature'].values[order]
        sal_vals = group['salinity'].values[order]

        # NOTE(review): gsw.sigma0 is documented as taking Absolute Salinity
        # and Conservative Temperature; the dataset's 'salinity' and
        # 'temperature' are passed through unchanged - confirm their meaning.
        dens_vals = gsw.sigma0(sal_vals, temp_vals)

        temp_profiles.append(
            np.interp(regular_depth, depth_vals, temp_vals, left=np.nan, right=np.nan))
        dens_profiles.append(
            np.interp(regular_depth, depth_vals, dens_vals, left=np.nan, right=np.nan))
        valid_times.append(t)

    if not temp_profiles:
        print(f"No usable profiles at station {station_id}")
        return

    # Arrange as (depth, time) for pcolormesh / contour
    temp = np.array(temp_profiles).T
    density = np.array(dens_profiles).T
    times = np.array(valid_times).astype('datetime64[ns]')

    # Plotting
    fig, ax = plt.subplots(figsize=(1.5 * 1.5 * 6.4, 1.5 * 4.8))
    pc = ax.pcolormesh(times, regular_depth, temp, cmap=cm.cm.thermal, shading='nearest', vmin=5.3, vmax=10)

    # (levels, colour, linewidth) for each isopycnal family
    isopycnal_styles = [
        (np.linspace(24, 27, 7), 'black', .51),
        ([25.6], 'white', 0.5),
        ([25.7], 'lime', 0.5),
        ([25.8], 'red', 0.5),
        ([25.9], 'blue', 0.5),
        ([26.0], 'black', 0.51),
        ([26.1], 'purple', 0.5),
        ([26.2], 'salmon', 0.5),
        ([26.3], 'yellow', 0.5),
        ([26.4], 'cyan', 0.5),
    ]
    # ax.contour needs at least a 2x2 field, so the isopycnals are skipped
    # when only a single profile survived the filters.
    if density.shape[1] >= 2:
        for levels, color, lw in isopycnal_styles:
            cf_iso = ax.contour(times, regular_depth, density, levels=levels, colors=color, linewidths=lw)
            # Every family is labelled (the former `lw != 0.3` test was always true).
            ax.clabel(cf_iso, fmt='%1.2f')

    # Dashed reference lines at 130 m and 120 m across the sampled time span
    ax.plot(times, [130] * len(times), color='black', linewidth=2, linestyle='--')
    ax.plot(times, [120] * len(times), color='black', linewidth=2, linestyle='--')

    ax.set_title(f"Temperature Time Series at {station_id}")
    ax.set_xlabel("Date")
    ax.set_ylabel("Depth (m)")
    # set_ylim fixes both limits explicitly, so no separate invert_yaxis() is needed.
    ax.set_ylim(ylim)
    fig.colorbar(pc, ax=ax, label="Temperature (°C)")
    ax.xaxis.set_major_formatter(DateFormatter('%b %d\n%Y'))

    # Secondary top axis with one tick per profile time
    ax_top = ax.secondary_xaxis('top')
    ax_top.set_xticks(times)
    ax_top.set_xticklabels([pd.Timestamp(t).strftime('%H:%M\n%d %b') for t in times], rotation=90, fontsize=8)
    ax_top.tick_params(axis='x', direction='out', length=3, width=0.5, pad=2)

    plt.tight_layout()

# 2024-2025 temperature sections, each station cropped to its own depth range.
for _station, _ylim in (("FZH08", (400, 0)), ("QCS07", (120, 0))):
    plot_station_timeseries(_station, ds, years=[2024, 2025], ylim=_ylim)

In [None]:
# # plot_station_timeseries("QCS07", ds, ylim= (125, 0), years = [2020, 2021, 2022, 2023, 2024, 2025])
# # plot_station_timeseries("QCS01", ds, ylim= (120, 0), years = [2020, 2021, 2022, 2023, 2024, 2025])
# # plot_station_timeseries("FZH08", ds, years = [2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023, 2024, 2025])
# # plot_station_timeseries("DFO2", ds, years = [2020, 2021, 2022, 2023, 2024, 2025])
# plot_station_timeseries("KC10", ds, years = [2020, 2021, 2022, 2023, 2024, 2025], ylim = (400,0))
# plot_station_timeseries("HKP04", ds, years = [2020, 2021, 2022, 2023, 2024, 2025], ylim = (400, 0))
# plot_station_timeseries("FZH14", ds, years = [2020, 2021, 2022, 2023, 2024, 2025], ylim = (400, 0))
# NOTE(review): the same FZH08 / 2024-2025 call appears again further down in
# this cell, so the figure is drawn more than once per run.
plot_station_timeseries("FZH08", ds, ylim=(400, 0), years=[2024, 2025])
# # plot_station_timeseries("FZH14", ds, years = [2023, 2024, 2025])
# # plot_station_timeseries("FZH08", ds, years = [2023, 2024])
# # plot_station_timeseries("FZH04", ds, years = [2020, 2021, 2022, 2023, 2024, 2025])
# # plot_station_timeseries("FZH10", ds, years = [2020, 2021, 2022, 2023, 2024, 2025])
# # plot_station_timeseries("FC2", ds, years = [2020, 2021, 2022, 2023, 2024, 2025])
# # plot_station_timeseries("HKP04", ds, years = [2020, 2021, 2022], ylim = (400, 250))
# # plot_station_timeseries("HKP04", ds, years = [2023], ylim = (400, 250))
# plot_station_timeseries("HKP01", ds, years = [2020], ylim = (270, 0))
# plot_station_timeseries("HKP03", ds, years = [2020], ylim = (400, 0))

# plot_station_timeseries("HKP05", ds, years = [2020], ylim = (240, 0))
# plot_station_timeseries("HKP06", ds, years = [2020], ylim = (400, 0))
# plot_station_timeseries("FZH01", ds, years = [2020], ylim = (260, 0))
# plot_station_timeseries("FZH13", ds, years = [2020], ylim = (220, 0))
# plot_station_timeseries("QCS01", ds, years = [2020], ylim = (150, 0))
# plot_station_timeseries("QCS08", ds, years = [2020], ylim = (150, 0))
# plot_station_timeseries("DFO2", ds, years = [2020], ylim = (350, 0))
# plot_station_timeseries("QCS02", ds, years = [2020], ylim = (400, 0))
# plot_station_timeseries("FC1", ds, years = [2023, 2024])
# plot_station_timeseries("DE2", ds, years = [2020, 2021, 2022, 2023, 2024, 2025])
# plot_station_timeseries("DE1", ds, years = [2020, 2021, 2022, 2023, 2024, 2025])
# plot_station_timeseries("DE3", ds, years = [2020, 2021, 2022, 2023, 2024, 2025])
# plot_station_timeseries("QCS03", ds, years = [2020], ylim = (400, 0))
# plot_station_timeseries("QCS04", ds, years = [2020], ylim = (400, 0))
# plot_station_timeseries("QCS05", ds, years = [2020], ylim = (400, 0))
# plot_station_timeseries("QCS06", ds, years = [2020], ylim = (400, 0))
# Active selection: 2024-2025 sections for FZH08 (to 400 m) and QCS07 (to 120 m).
for _station, _ylim in (("FZH08", (400, 0)), ("QCS07", (120, 0))):
    plot_station_timeseries(_station, ds, years=[2024, 2025], ylim=_ylim)
# plot_station_timeseries("QCS08", ds, years = [2020], ylim = (400, 0))
# plot_station_timeseries("HKP06", ds, years = [2020, 2021, 2022, 2023, 2024, 2025])
# plot_station_timeseries("QCS07", df, ylim = (120, 0), years= [2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023, 2024, 2025])
# print("QCS07 max depth:", df[df['station'] == 'QCS07']['depth'].max(), "m")
# print("QCS01 max depth:", df[df['station'] == 'QCS01']['depth'].max(), "m")
# print("QCS02 max depth:", df[df['station'] == 'QCS02']['depth'].max(), "m")
# print("QCS03 max depth:", df[df['station'] == 'QCS03']['depth'].max(), "m")

# plot_station_timeseries("FZH01", ds, years = [2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023, 2024, 2025], ylim = (420, 0))
# plot_station_timeseries("FZH04", ds, years = [2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023, 2024, 2025], ylim = (420, 0))
# plot_station_timeseries("FZH07", ds, years = [2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023, 2024, 2025], ylim = (420, 0))
# plot_station_timeseries("FZH08", ds, years = [2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023, 2024, 2025], ylim = (420, 0))
# plot_station_timeseries("FZH13", ds, years = [2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023, 2024, 2025], ylim = (420, 0))
# plot_station_timeseries("FZH14", ds, years = [2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023, 2024, 2025], ylim = (420, 0))

## Define the station names
# station_names = [
#     "DE1",
#     "DE2",
#     "DE3",
#     "DFO2",
#     "FC1",
#     "FZH01",
#     "FZH04",
#     "FZH08",
#     "FZH10",
#     "FZH13",
#     "FZH14",
#     "HKP01",
#     "HKP03",
#     "HKP04",
#     "HKP05",
#     "HKP06",
#     "KC10",
#     "QCS01",
#     "QCS02",
#     "QCS03",
#     "QCS04",
#     "QCS05",
#     "QCS06",
#     "QCS07",
#     "QCS08"
# ]

# # Define years
# years = list(range(2010, 2026))

# # Loop over stations and plot
# for station in station_names:
#     print(f"Plotting {station}...")
#     plot_station_timeseries(
#         station,
#         ds,
#         years=years,
#         ylim=(420, 0))

In [None]:
def plot_station_time_series_subplots(
    ds,
    station_ids,
    ylims,
    years=[2023, 2024, 2025]
):
    """
    Plot vertically stacked temperature sections for several stations.

    Each station gets its own panel (depth vs. time, temperature as colour,
    sigma0 as labelled contours). All panels are forced to the same time
    limits, and panel heights are proportional to the first element of the
    matching ``ylims`` entry.

    Parameters
    ----------
    ds : xarray.Dataset
        Must provide ``station``, ``time``, ``depth``, ``temperature`` and
        ``salinity`` along a ``row`` dimension.
    station_ids : list of str
        Stations to plot, top to bottom.
    ylims : list of tuple
        One ``(bottom, top)`` pair per station, passed to ``ax.set_ylim``.
        ``ylims[i][0]`` also sets the relative height of panel ``i``.
    years : list of int
        Calendar years to keep. The default list is never mutated.

    Returns
    -------
    None
        Prints a message and returns early when no station has a usable profile.

    Notes
    -----
    ``plt.rcParams`` is updated globally, so the sizes set here persist for
    figures created afterwards.
    """
    import matplotlib.pyplot as plt
    from matplotlib.dates import DateFormatter
    import numpy as np
    import cmocean as cm
    import gsw
    import pandas as pd

    # Apply the style before any Axes exist. Updating rcParams after the
    # figure is built (as was done previously) has no effect on that figure
    # and only changes whatever is plotted next.
    plt.rcParams.update({
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
        "axes.titlesize": 20,
        "axes.labelsize": 20,
        "figure.titlesize": 24
    })

    # Panel heights scale with the deepest limit requested for each station
    max_depths = [yl[0] for yl in ylims]
    max_depth = max(max_depths)
    height_ratios = [d / max_depth for d in max_depths]

    n = len(station_ids)
    fig, axes = plt.subplots(
        nrows=n,
        figsize=(1.2 * 16, 1.2 * 14),
        constrained_layout=True,
        gridspec_kw={'height_ratios': height_ratios}
    )
    fig.suptitle("Temperature Time Series at Stations", fontsize=30)

    # plt.subplots returns a bare Axes for a single row; normalise to a list
    if n == 1:
        axes = [axes]

    # Common 1 m depth grid every profile is interpolated onto
    regular_depth = np.arange(0, 500, 1)

    # Profile times across all stations, used for the shared x-limits
    all_times_for_limits = []
    station_data = []

    # First pass: build the interpolated profiles for every station
    for station_id in station_ids:
        ds_station = ds.where(ds['station'] == station_id, drop=True)
        ds_station = ds_station.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time'])
        ds_station = ds_station.where(ds_station['time'].dt.year.isin(years), drop=True)

        temp_profiles = []
        dens_profiles = []
        valid_times = []

        # One group per distinct timestamp == one profile
        for t, group in ds_station.groupby('time'):
            if group.sizes['row'] < 2:
                continue

            # np.interp requires increasing x-coordinates; sort by depth first.
            depth_vals = group['depth'].values
            order = np.argsort(depth_vals, kind='stable')
            depth_vals = depth_vals[order]
            temp_vals = group['temperature'].values[order]
            sal_vals = group['salinity'].values[order]

            # NOTE(review): gsw.sigma0 is documented as taking Absolute
            # Salinity and Conservative Temperature; the dataset values are
            # passed through unchanged - confirm their meaning.
            dens_vals = gsw.sigma0(sal_vals, temp_vals)

            temp_profiles.append(
                np.interp(regular_depth, depth_vals, temp_vals, left=np.nan, right=np.nan))
            dens_profiles.append(
                np.interp(regular_depth, depth_vals, dens_vals, left=np.nan, right=np.nan))

            time_dt = pd.Timestamp(t).to_pydatetime()
            valid_times.append(time_dt)
            all_times_for_limits.append(time_dt)

        station_data.append((station_id, valid_times, temp_profiles, dens_profiles))

    if not all_times_for_limits:
        print("No valid data found for any station.")
        return

    # Determine global time range
    tmin = min(all_times_for_limits)
    tmax = max(all_times_for_limits)

    # (levels, colour, linewidth) for each isopycnal family
    isopycnal_styles = [
        (np.linspace(24, 27, 7), 'black', 0.5),
        ([25.6], 'white', 2),
        ([25.7], 'lime', 2),
        ([25.8], 'red', 2),
        ([25.9], 'blue', 2),
    ]

    # Second pass: draw one panel per station
    for ax, (station_id, valid_times, temp_profiles, dens_profiles), ylim in zip(
        axes, station_data, ylims
    ):
        if not valid_times:
            print(f"No usable profiles at station {station_id}")
            continue

        # Arrange as (depth, time)
        temp = np.array(temp_profiles).T
        dens = np.array(dens_profiles).T

        # Plot temperature
        pc = ax.pcolormesh(
            valid_times, regular_depth, temp,
            cmap=cm.cm.thermal, shading='nearest',
            vmin=5.3, vmax=10
        )

        # Density contours. ax.contour needs at least a 2x2 field, so they
        # are skipped for a station with a single profile.
        if dens.shape[1] >= 2:
            for levels, color, lw in isopycnal_styles:
                cf = ax.contour(valid_times, regular_depth, dens,
                                levels=levels, colors=color, linewidths=lw)
                # Every family is labelled (the former `lw != 0.3` test was always true).
                ax.clabel(cf, fmt='%1.2f', fontsize=12)

        # Sill reference line
        ax.plot(valid_times, [120] * len(valid_times), color='black', linewidth=2, label='Approx. sill depth')

        # set_ylim fixes both limits explicitly, so no separate invert_yaxis() is needed.
        ax.set_ylim(ylim)
        ax.set_ylabel("Depth (m)", fontsize=20)
        ax.set_title(f"{station_id}", fontsize=24)

        # Force all subplots to share the same time limits
        ax.set_xlim(tmin, tmax)

        ax.xaxis.set_major_formatter(DateFormatter('%b %d\n%Y'))
        if ax is not axes[-1]:
            # Only the bottom panel carries date labels
            ax.set_xticklabels([])
        else:
            ax.set_xlabel("Date", fontsize=20)

        # Top ticks for each subplot: one per profile time
        ax_top = ax.secondary_xaxis('top')
        ax_top.set_xticks(valid_times)
        ax_top.set_xticklabels(
            [pd.Timestamp(t).strftime('%b %d') for t in valid_times],
            rotation=90,
            fontsize=10)
        ax_top.tick_params(axis='x', direction='out', length=3, width=0.5, pad=2)

    # `pc` is always bound here: at least one station had a profile, otherwise
    # the function returned above.
    fig.colorbar(pc, ax=axes, label="Temperature (°C)")

# Stack QCS07 above FZH08 for 2024-2025, each panel cropped to its own depth range.
plot_station_time_series_subplots(
    ds, ["QCS07", "FZH08"], [(130, 0), (370, 0)], years=[2024, 2025])

In [None]:
def plot_station_profile_timeseries(station_id, ds, ylim=(450, 0), years=[2023, 2024, 2025]):
    """
    Plot all temperature and density profiles of one station, coloured by time.

    Rows of ``ds`` for ``station_id`` within ``years`` are split into one
    profile per distinct ``time``. Each profile is sorted by depth and
    interpolated onto a 1 m grid covering 0-422 m (NaN outside the sampled
    range). The left panel shows temperature, the right panel sigma0 from
    ``gsw.sigma0``; line colour encodes the profile time on a viridis scale.

    Parameters
    ----------
    station_id : str
        Value compared against ``ds['station']``.
    ds : xarray.Dataset
        Must provide ``station``, ``time``, ``depth``, ``temperature`` and
        ``salinity`` along a ``row`` dimension.
    ylim : tuple
        Passed unchanged to ``set_ylim`` of both (y-sharing) panels.
    years : list of int
        Calendar years to keep. The default list is never mutated.

    Returns
    -------
    None
        Prints a message and returns early when nothing can be plotted.

    Notes
    -----
    ``plt.rcParams`` is updated globally, so the sizes set here persist for
    figures created afterwards.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import gsw
    import pandas as pd
    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 12,
        'figure.titlesize': 20})

    # Filter to the desired station, complete rows and requested years
    ds_station = ds.where(ds['station'] == station_id, drop=True)
    ds_station = ds_station.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time'])
    ds_station = ds_station.where(ds_station['time'].dt.year.isin(years), drop=True)

    if ds_station.sizes['row'] < 2:
        print(f"No valid data for station {station_id} in years {years}.")
        return

    # Common 1 m depth grid every profile is interpolated onto
    regular_depth = np.arange(0, 423, 1)
    temp_profiles = []
    dens_profiles = []
    valid_times = []

    for t in np.unique(ds_station['time'].values):
        profile = ds_station.where(ds_station['time'] == t, drop=True)
        if profile.sizes['row'] < 2:
            continue

        # np.interp requires increasing x-coordinates; sort by depth first.
        depth_vals = profile['depth'].values
        order = np.argsort(depth_vals, kind='stable')
        depth_vals = depth_vals[order]
        temp_vals = profile['temperature'].values[order]
        sal_vals = profile['salinity'].values[order]

        # NOTE(review): gsw.sigma0 is documented as taking Absolute Salinity
        # and Conservative Temperature; the dataset values are passed through
        # unchanged - confirm their meaning.
        dens_vals = gsw.sigma0(sal_vals, temp_vals)

        temp_profiles.append(
            np.interp(regular_depth, depth_vals, temp_vals, left=np.nan, right=np.nan))
        dens_profiles.append(
            np.interp(regular_depth, depth_vals, dens_vals, left=np.nan, right=np.nan))
        valid_times.append(t)

    if not temp_profiles:
        print(f"No usable profiles at station {station_id}")
        return

    # Arrange as (depth, time)
    temp = np.array(temp_profiles).T
    density = np.array(dens_profiles).T
    times = np.array(valid_times).astype('datetime64[ns]')

    # Map each profile time (as integer nanoseconds) onto the viridis scale
    cmap = plt.get_cmap('viridis')
    norm = plt.Normalize(times.min().astype('int64'), times.max().astype('int64'))
    colors = [cmap(norm(t.astype('int64'))) for t in times]

    # Plotting
    fig, axes = plt.subplots(1, 2, figsize=(10, 8), sharey=True)

    # Temperature profiles
    for i, t in enumerate(times):
        axes[0].plot(temp[:, i], regular_depth, color=colors[i], label=pd.Timestamp(t).strftime('%d %b %Y'))
    axes[0].set_xlabel("Temperature (°C)")
    axes[0].set_ylabel("Depth (m)")
    axes[0].set_title("Temperature Profiles")

    # Density profiles
    for i in range(len(times)):
        axes[1].plot(density[:, i], regular_depth, color=colors[i])
    axes[1].set_xlabel("Potential density (kg/m³)")
    axes[1].set_title("Density Profiles")

    # The panels share their y-axis, so calling invert_yaxis() on each one
    # flips the shared axis twice. Orientation is set once, explicitly, here.
    axes[0].set_ylim(ylim)

    # Colorbar
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=axes.ravel().tolist(), orientation='vertical', fraction=0.02, pad=0.02)
    cbar.set_label('Profile Time')
    # Five evenly spaced ticks, relabelled from nanoseconds back to dates
    tick_locs = np.linspace(times.min().astype('int64'), times.max().astype('int64'), 5)
    tick_labels = [pd.Timestamp(np.datetime64(int(t), 'ns')).strftime('%d %b') for t in tick_locs]
    cbar.set_ticks(tick_locs)
    cbar.set_ticklabels(tick_labels)

# All HKP04 profiles from 2020, shown down to 400 m.
plot_station_profile_timeseries(
    "HKP04",
    ds,
    ylim=(400, 0),
    years=[2020])

In [None]:
def compute_mean_density_time_series(
    ds,
    station_id,
    depth_range=(130, 350),
    years=[2023, 2024, 2025]
):
    """
    Compute the mean sigma0 between two depths for every profile of a station.

    Rows of ``ds`` for ``station_id`` within ``years`` are split into one
    profile per distinct ``time``. Each profile is sorted by depth, its sigma0
    (``gsw.sigma0``) is interpolated onto a 1 m grid spanning ``depth_range``
    (both ends included), and the NaN-ignoring mean over that grid is taken.
    Grid points outside a profile's sampled depths are NaN, so a profile that
    covers only part of ``depth_range`` is averaged over that part only.

    Parameters
    ----------
    ds : xarray.Dataset
        Must provide ``station``, ``time``, ``depth``, ``temperature`` and
        ``salinity`` along a ``row`` dimension.
    station_id : str
        Value compared against ``ds['station']``.
    depth_range : tuple
        ``(shallow, deep)`` bounds of the averaging window.
    years : list of int
        Calendar years to keep. The default list is never mutated.

    Returns
    -------
    pandas.DataFrame or None
        Columns ``Date`` and ``MeanDensity``, sorted by ``Date``. ``None``
        (after printing a message) when no usable profile is found.
    """
    import numpy as np
    import pandas as pd
    import gsw

    # Filter to the desired station, complete rows and requested years
    ds_station = ds.where(ds['station'] == station_id, drop=True)
    ds_station = ds_station.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time'])
    ds_station = ds_station.where(ds_station['time'].dt.year.isin(years), drop=True)

    if ds_station.sizes['row'] < 2:
        print(f"No valid data for station {station_id} in years {years}.")
        return None

    # 1 m grid over the averaging window, upper bound included
    regular_depth = np.arange(depth_range[0], depth_range[1]+1, 1)

    avg_densities = []
    valid_times = []

    for t in np.unique(ds_station['time'].values):
        profile = ds_station.where(ds_station['time'] == t, drop=True)
        if profile.sizes['row'] < 2:
            continue

        # np.interp requires increasing x-coordinates; sort by depth first.
        depth_vals = profile['depth'].values
        order = np.argsort(depth_vals, kind='stable')

        # NOTE(review): gsw.sigma0 is documented as taking Absolute Salinity
        # and Conservative Temperature; the dataset values are passed through
        # unchanged - confirm their meaning.
        dens_vals = gsw.sigma0(
            profile['salinity'].values[order],
            profile['temperature'].values[order])

        dens_interp = np.interp(
            regular_depth,
            depth_vals[order],
            dens_vals,
            left=np.nan,
            right=np.nan
        )
        # Profile does not overlap the averaging window at all
        if np.all(np.isnan(dens_interp)):
            continue

        avg_densities.append(np.nanmean(dens_interp))
        valid_times.append(pd.Timestamp(t))

    if not avg_densities:
        print(f"No usable profiles at station {station_id}")
        return None

    # Create DataFrame
    df = pd.DataFrame({
        'Date': valid_times,
        'MeanDensity': avg_densities
    }).sort_values('Date').reset_index(drop=True)

    return df

# Mean density of the upper 350 m at HKP04 for every 2024-2025 profile.
df_density = compute_mean_density_time_series(
    ds, "HKP04", depth_range=(0, 350), years=[2024, 2025])
print(df_density)

def plot_mean_density_time_series(df, station_id):
    """
    Plot a mean-density time series as a line with markers.

    Parameters
    ----------
    df : pandas.DataFrame or None
        Must provide a datetime ``Date`` column and a ``MeanDensity`` column.
        ``None`` or an empty frame is accepted and results in a printed
        message instead of a figure.
    station_id : str
        Used in the title only.

    Returns
    -------
    None
    """
    # compute_mean_density_time_series returns None when nothing is usable;
    # bail out instead of failing on df['Date'] or on min() of an empty column.
    if df is None or len(df) == 0:
        print(f"No mean-density data to plot for station {station_id}.")
        return

    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(df['Date'], df['MeanDensity'], marker='o', color='navy')
    # The title carries the covered date span
    ax.set_title(f"Mean Density Time Series at {station_id}\n({df['Date'].min().strftime('%Y-%m-%d')} to {df['Date'].max().strftime('%Y-%m-%d')})")
    ax.set_xlabel("Date")
    ax.set_ylabel("Depth averaged density (full profile) (kg/m³)")
    plt.setp(ax.get_xticklabels(), rotation=90, ha='center')
    ax.grid(True)
    plt.tight_layout()
    plt.show()
# Draw the HKP04 series computed above.
plot_mean_density_time_series(df=df_density, station_id="HKP04")

In [None]:
%matplotlib widget
def plot_ts_from_station(ds, station_id, years=[2023, 2024, 2025],
                         depth_range=(100, 450), xlim=None, ylim=None, target_months=None):
    """
    Draw a T-S diagram for one station, one scatter series per profile.

    Rows of ``ds`` for ``station_id`` are restricted to ``years`` (and to
    ``target_months`` when given) and split into one profile per distinct
    ``time``. Samples whose depth lies inside ``depth_range`` (inclusive) are
    scattered as salinity vs. temperature over dashed sigma0 contours computed
    with ``gsw.sigma0`` on a regular salinity/temperature grid.

    Parameters
    ----------
    ds : xarray.Dataset
        Must provide ``station``, ``time``, ``depth``, ``temperature`` and
        ``salinity`` along a ``row`` dimension.
    station_id : str
        Value compared against ``ds['station']``.
    years : list of int
        Calendar years to keep. The default list is never mutated.
    depth_range : tuple
        ``(shallow, deep)`` depth window of the samples to show.
    xlim, ylim : tuple or None
        Axis limits. When ``None`` they are derived from the plotted data
        with a 0.1 margin (left automatic if nothing was plotted).
    target_months : list of int or None
        Calendar months to keep; ``None`` keeps all months.

    Returns
    -------
    None
        Prints a message and returns early when fewer than two rows remain.

    Notes
    -----
    ``plt.rcParams`` is updated globally, so the sizes set here persist for
    figures created afterwards.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import gsw
    import itertools
    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 20,
        'figure.titlesize': 20})

    # Filter and clean
    ds_station = ds.where(ds['station'] == station_id, drop=True)
    ds_station = ds_station.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time'])
    ds_station = ds_station.where(ds_station['time'].dt.year.isin(years), drop=True)
    if target_months is not None:
        ds_station = ds_station.where(ds_station['time'].dt.month.isin(target_months), drop=True)

    if ds_station.sizes['row'] < 2:
        print(f"No valid data for station {station_id} in years {years}.")
        return

    # One colour per distinct profile time, spread over the 'jet' colormap.
    # (The unused depth grid and whole-dataset sigma0 were dropped: the raw
    # samples are plotted, nothing is interpolated here.)
    times = np.unique(ds_station['time'].values)
    cmap = plt.colormaps.get_cmap('jet')
    colors = cmap(np.linspace(0, 1, len(times)))
    marker = itertools.cycle(['o', 'D', 's', 'X', 'P', '^', 'v', '<', '>', '*', 'h'])

    all_S, all_T = [], []

    fig, ax = plt.subplots(figsize=(1.5 * 1.5 * 6.4, 1.5 * 4.8))
    ax.set_facecolor('lightgrey')

    # Background sigma0 field on a regular salinity/temperature grid
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 200),
                                 np.linspace(5.3, 20, 200))
    sigma = gsw.sigma0(S_grid, T_grid)

    # (levels, colour, linewidth); a linewidth of exactly 0.5 marks a family
    # that is drawn without labels, anything else (0.51) is labelled.
    isopycnal_styles = [
        (np.linspace(24, 27, 7), 'black', .51),
        ([25.6], 'white', 0.5),
        ([25.7], 'lime', 0.5),
        ([25.8], 'red', 0.5),
        ([25.9], 'blue', 0.5),
        ([26.0], 'black', 0.51),
        ([26.1], 'purple', 0.5),
        ([26.2], 'salmon', 0.5),
        ([26.3], 'yellow', 0.5),
        ([26.4], 'cyan', 0.5),
    ]
    for levels, color, lw in isopycnal_styles:
        cs = ax.contour(S_grid, T_grid, sigma, levels=levels,
                        colors=color, linewidths=lw, linestyles='--')
        if lw != 0.5:
            ax.clabel(cs, fmt='%1.2f', fontsize=9, inline=True)

    # Plot each profile
    for i, t in enumerate(times):
        profile = ds_station.where(ds_station['time'] == t, drop=True)
        if profile.sizes['row'] < 2:
            continue

        depth = profile['depth'].values
        temp = profile['temperature'].values
        sal = profile['salinity'].values

        # Keep only samples inside the requested depth window
        mask = (depth >= depth_range[0]) & (depth <= depth_range[1])
        temp = temp[mask]
        sal = sal[mask]

        finite = ~np.isnan(temp) & ~np.isnan(sal)
        if np.any(finite):
            # The marker only advances for profiles that are actually drawn
            ax.scatter(sal[finite], temp[finite],
                       color=colors[i],
                       s=90,
                       marker=next(marker),
                       linewidth=0.2,
                       label=pd.Timestamp(t).strftime('%d %b %Y'))

            all_S.extend(sal[finite])
            all_T.extend(temp[finite])

    plotted_any = len(all_S) > 0

    # Axis limits: fall back to the data extent when not given explicitly
    if plotted_any:
        all_S = np.array(all_S)
        all_T = np.array(all_T)
        if xlim is None:
            xlim = (np.nanmin(all_S) - 0.1, np.nanmax(all_S) + 0.1)
        if ylim is None:
            ylim = (np.nanmin(all_T) - 0.1, np.nanmax(all_T) + 0.1)

    ax.set_xlabel("Salinity (psu)")
    ax.set_ylabel(r"$\theta$ (°C)")
    ax.set_title(f"T–S Diagram at {station_id}")
    # A legend without any labelled artist only produces a warning
    if plotted_any:
        ax.legend(fontsize=9, loc='upper right')
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    plt.tight_layout()

# plot_ts_from_station(ds, 'FZH08', years=[2023], depth_range=(200, 450), target_months= [6,7,8,9,10,11,12])
# plot_ts_from_station(ds, 'FZH08', years=[2023], depth_range=(300, 450), target_months= [6,7,8,9,10,11,12])
# plot_ts_from_station(ds, 'HKP04', years=[2023], depth_range=(300, 450), target_months= [6,7,8,9,10,11,12])
# plot_ts_from_station(ds, 'HKP04', years=[2020], depth_range=(10, 450), target_months= [1,2,3,4,5,6,7,8,9,10,11,12])
# plot_ts_from_station(ds, 'HKP03', years=[2020], depth_range=(10, 450), target_months= [1,2,3,4,5,6,7,8,9,10,11,12])
# plot_ts_from_station(ds, 'QCS01', years=[2020], depth_range=(10, 450), target_months= [1,2,3,4,5,6,7,8,9,10,11,12])
# plot_ts_from_station(ds, 'QCS07', years=[2020], depth_range=(10, 450), target_months= [1,2,3,4,5,6,7,8,9,10,11,12])
# plot_ts_from_station(ds, 'DE1', years=[2022], depth_range=(0, 500), target_months= [1,2,3,4,5,6,7,8,9,10,11,12])
# plot_ts_from_station(ds, 'DE1', years=[2023], depth_range=(0, 500), target_months= [1,2,3,4,5,6,7,8,9,10,11,12])
# plot_ts_from_station(ds, 'DE1', years=[2024], depth_range=(0, 500), target_months= [1,2,3,4,5,6,7,8,9,10,11,12])
# BUR3: 2020-2025 temperature section, then a T-S diagram for 2023-2024 (all months).
plot_station_timeseries("BUR3", ds, ylim=(320, 0), years=list(range(2020, 2026)))
plot_ts_from_station(
    ds, 'BUR3',
    years=[2023, 2024],
    depth_range=(0, 500),
    target_months=list(range(1, 13)))


In [None]:
# DE1 T-S diagram for 2022-2023, all months, samples between 0 and 500 m.
plot_ts_from_station(
    ds, 'DE1',
    years=[2022, 2023],
    depth_range=(0, 500),
    target_months=list(range(1, 13)))

In [None]:
# plot_ts_from_station(ds, 'FZH08', years=[2024, 2025], depth_range=(0, 450), target_months= [1,2,3,4,5,6,7,8,9,10,11,12])

In [None]:
# plot_ts_from_station(ds, 'DE1', years=[2020, 2021, 2022, 2023, 2024, 2025], depth_range=(0, 500), target_months= [1,2,3,4,5,6,7,8,9,10,11,12])

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import gsw
import itertools

def plot_ts_compare_stations(ds, station_ids, years=[2023], 
                             depth_range=(100, 450), target_months=None,
                             xlim=None, ylim=None):
    """
    Compare T-S diagrams across multiple stations on one set of axes.

    For every station, rows of ``ds`` are restricted to ``years`` (and to
    ``target_months`` when given) and split into one profile per distinct
    ``time``. Samples inside ``depth_range`` (inclusive) are scattered as
    salinity vs. temperature over dashed sigma0 contours. Colour identifies
    the station; the marker cycles through the profiles of a station and
    restarts for each station.

    Parameters
    ----------
    ds : xarray.Dataset
        Must provide ``station``, ``time``, ``depth``, ``temperature`` and
        ``salinity`` along a ``row`` dimension.
    station_ids : list of str
        Stations to overlay.
    years : list of int
        Calendar years to keep. The default list is never mutated.
    depth_range : tuple
        ``(shallow, deep)`` depth window of the samples to show.
    target_months : list of int or None
        Calendar months to keep; ``None`` keeps all months.
    xlim, ylim : tuple or None
        Axis limits. When ``None`` they are derived from the plotted data
        with a 0.1 margin (left automatic if nothing was plotted).

    Returns
    -------
    None
        Stations without data are reported with a printed message and skipped.

    Notes
    -----
    ``plt.rcParams`` is updated globally, so the sizes set here persist for
    figures created afterwards.
    """
    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 12,
        'figure.titlesize': 20})

    # plt.cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap is available
    # in old and new releases alike.
    base_cmap = plt.get_cmap('tab20')
    # Cycle over the whole palette (base_cmap.N entries) so colours only
    # repeat once there are more stations than palette entries.
    station_colors = {sid: base_cmap(i % base_cmap.N) for i, sid in enumerate(station_ids)}
    marker_shapes = ['o', 'D', 's', 'X', 'P', '^', 'v', '<', '>', '*', 'h', '+']

    fig, ax = plt.subplots(figsize=(12, 7))

    # Sigma-theta background grid
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 200),
                                 np.linspace(5.3, 20, 200))
    sigma = gsw.sigma0(S_grid, T_grid)

    # (levels, colour, linewidth); a linewidth of exactly 0.5 marks a family
    # that is drawn without labels, the thick lines are labelled.
    isopycnal_styles = [
        (np.linspace(24, 27, 7), 'black', 0.5),
        (np.linspace(26, 27, 11), 'black', 0.5),
        ([26], 'black', 2),
        ([25.6], 'yellow', 2),
        ([25.7], 'lime', 2),
        ([25.8], 'red', 2),
        ([25.9], 'blue', 2),
        ([26.6], 'yellow', 2),
        ([26.7], 'lime', 2),
        ([26.8], 'red', 2),
        ([26.9], 'blue', 2)
    ]
    for levels, color, lw in isopycnal_styles:
        cs = ax.contour(S_grid, T_grid, sigma, levels=levels,
                        colors=color, linewidths=lw, linestyles='--')
        if lw != 0.5:
            ax.clabel(cs, fmt='%1.2f', fontsize=9, inline=True)

    all_S, all_T = [], []

    # Loop over stations
    for station_id in station_ids:
        ds_station = ds.where(ds['station'] == station_id, drop=True)
        ds_station = ds_station.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time'])
        ds_station = ds_station.where(ds_station['time'].dt.year.isin(years), drop=True)
        if target_months is not None:
            ds_station = ds_station.where(ds_station['time'].dt.month.isin(target_months), drop=True)

        if ds_station.sizes['row'] < 2:
            print(f"No valid data for station {station_id} in years {years}.")
            continue

        times = np.unique(ds_station['time'].values)
        # Marker sequence restarts for every station
        marker = itertools.cycle(marker_shapes)

        for t in times:
            profile = ds_station.where(ds_station['time'] == t, drop=True)
            if profile.sizes['row'] < 2:
                continue

            depth = profile['depth'].values
            temp = profile['temperature'].values
            sal = profile['salinity'].values

            # Keep only samples inside the requested depth window
            mask = (depth >= depth_range[0]) & (depth <= depth_range[1])
            temp = temp[mask]
            sal = sal[mask]

            valid = ~np.isnan(temp) & ~np.isnan(sal)
            if np.any(valid):
                ax.scatter(
                    sal[valid],
                    temp[valid],
                    color=station_colors[station_id],
                    edgecolors='black',
                    s=90,
                    marker=next(marker),
                    linewidth=0.3,
                    label=f"{station_id}: {pd.Timestamp(t).strftime('%d %b %Y')}"
                )
                all_S.extend(sal[valid])
                all_T.extend(temp[valid])

    plotted_any = len(all_S) > 0

    # Axis limits: fall back to the data extent when not given explicitly
    if plotted_any:
        all_S = np.array(all_S)
        all_T = np.array(all_T)
        if xlim is None:
            xlim = (np.nanmin(all_S) - 0.1, np.nanmax(all_S) + 0.1)
        if ylim is None:
            ylim = (np.nanmin(all_T) - 0.1, np.nanmax(all_T) + 0.1)

    ax.set_xlabel("Salinity (psu)")
    ax.set_ylabel(r"$\theta$ (°C)")
    ax.set_title("T–S Diagram Comparison")
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    # A legend without any labelled artist only produces a warning
    if plotted_any:
        ax.legend(fontsize=9, loc='upper right', ncol=2)
    plt.tight_layout()

# 2020 comparison, all months, samples between 100 and 450 m.
plot_ts_compare_stations(
    ds, ['DFO2', 'HKP04',"QCS01", "QCS07" ],
    years=[2020], depth_range=(100, 450),
    target_months=list(range(1, 13)))

# 2020 comparison of four QCS stations, May-December, samples between 10 and 450 m.
plot_ts_compare_stations(
    ds, ["QCS01","QCS02", "QCS07", "QCS08"],
    years=[2020], depth_range=(10, 450),
    target_months=list(range(5, 13)))


# plot_station_timeseries("HKP04", ds, years = [2020], ylim = (400, 0))
# plot_station_timeseries("DFO2", ds, years = [2020], ylim = (400, 0))
# plot_station_timeseries("QCS01", ds, years = [2020], ylim = (400, 0))
# plot_station_timeseries("QCS07", ds, years = [2020], ylim = (400, 0))

In [None]:
# July 2020 snapshot across twelve stations, samples between 100 and 450 m.
plot_ts_compare_stations(
    ds,
    ["QCS01","QCS02", "QCS07", "QCS08",'DFO2', 'HKP01', "HKP03", "HKP04", "HKP05", "HKP06", "FZH01", "FZH13"],
    years=[2020], depth_range=(100, 450), target_months=[7])

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import gsw
import itertools

def plot_ts_compare_stations(ds, station_ids, years=[2023],
                             depth_range=(100, 450), target_months=None,
                             xlim=None, ylim=None):
    """
    Compare T-S diagrams across multiple stations.

    Every cast (unique timestamp) of every requested station is drawn as one
    scatter series on top of dashed sigma-theta contours. All casts of a
    station share one viridis colour; casts within a station are told apart
    by marker shape.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset from which ``station``, ``time``, ``depth``, ``temperature``
        and ``salinity`` are read along the ``row`` dimension.
    station_ids : sequence of str
        Stations to draw, in legend/colour order.
    years : sequence of int
        Calendar years to keep (the default list is only read, never
        modified).
    depth_range : tuple of (float, float)
        Inclusive ``(min, max)`` depth window applied to each cast.
    target_months : sequence of int or None
        Calendar months to keep; ``None`` keeps every month.
    xlim, ylim : tuple or None
        Axis limits; when ``None`` they are derived from the plotted data
        with a 0.1 margin.

    Returns
    -------
    None
        The figure is created as a side effect. ``plt.rcParams`` is updated
        globally, and a message is printed for each station that has fewer
        than two usable rows.
    """
    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 12,
        'figure.titlesize': 20
    })

    # One viridis colour per station.  plt.get_cmap is used because
    # matplotlib.cm.get_cmap (plt.cm.get_cmap) no longer exists in
    # Matplotlib >= 3.9.
    viridis = plt.get_cmap('viridis')
    color_indices = np.linspace(0, 1, len(station_ids))
    station_colors = {sid: viridis(c) for sid, c in zip(station_ids, color_indices)}
    profile_markers = ['o', 'D', 's', 'X', 'P', '^', 'v', '<', '>', '*', 'h', '+']

    fig, ax = plt.subplots(figsize=(12, 7))

    # Sigma-theta background grid
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 200),
                                 np.linspace(5.3, 20, 200))
    sigma = gsw.sigma0(S_grid, T_grid)

    # Sigma-theta contours: thin black families first, then the highlighted
    # single isopycnals (only the highlighted ones get an inline label).
    for levels, color, lw in [
        (np.linspace(24, 27, 7), 'black', 0.5),
        (np.linspace(26, 27, 11), 'black', 0.5),
        ([26], 'black', 2),
        ([25.6], 'yellow', 2),
        ([25.7], 'lime', 2),
        ([25.8], 'red', 2),
        ([25.9], 'blue', 2),
        ([26.6], 'yellow', 2),
        ([26.7], 'lime', 2),
        ([26.8], 'red', 2),
        ([26.9], 'blue', 2)
    ]:
        cs = ax.contour(S_grid, T_grid, sigma, levels=levels,
                        colors=color, linewidths=lw, linestyles='--')
        if lw != 0.5:
            ax.clabel(cs, fmt='%1.2f', fontsize=9, inline=True)

    all_S, all_T = [], []

    # Loop over stations
    for station_id in station_ids:
        ds_station = ds.where(ds['station'] == station_id, drop=True)
        ds_station = ds_station.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time'])
        ds_station = ds_station.where(ds_station['time'].dt.year.isin(years), drop=True)
        if target_months is not None:
            ds_station = ds_station.where(ds_station['time'].dt.month.isin(target_months), drop=True)

        if ds_station.sizes['row'] < 2:
            print(f"No valid data for station {station_id} in years {years}.")
            continue

        # Pull the columns out once; selecting casts with boolean masks on
        # plain arrays avoids one Dataset.where(..., drop=True) per cast.
        time_vals = ds_station['time'].values
        depth_vals = ds_station['depth'].values
        temp_vals = ds_station['temperature'].values
        sal_vals = ds_station['salinity'].values
        usable = ((depth_vals >= depth_range[0]) & (depth_vals <= depth_range[1])
                  & ~np.isnan(temp_vals) & ~np.isnan(sal_vals))

        # Marker sequence restarts for every station.
        marker_iter = itertools.cycle(profile_markers)

        for t in np.unique(time_vals):
            in_cast = time_vals == t
            # A cast needs at least two rows to count as a profile.
            if in_cast.sum() < 2:
                continue

            sel = in_cast & usable
            if not sel.any():
                continue

            ax.scatter(
                sal_vals[sel],
                temp_vals[sel],
                color=station_colors[station_id],
                edgecolors='black',
                s=90,
                marker=next(marker_iter),
                linewidth=0.3,
                label=f"{station_id}: {pd.Timestamp(t).strftime('%d %b %Y')}"
            )
            all_S.extend(sal_vals[sel])
            all_T.extend(temp_vals[sel])

    # Axis limits
    if all_S and all_T:
        all_S = np.array(all_S)
        all_T = np.array(all_T)
        if xlim is None:
            xlim = (np.nanmin(all_S) - 0.1, np.nanmax(all_S) + 0.1)
        if ylim is None:
            ylim = (np.nanmin(all_T) - 0.1, np.nanmax(all_T) + 0.1)

    ax.set_xlabel("Salinity (psu)")
    ax.set_ylabel(r"$\theta$ (°C)")
    ax.set_title("T–S Diagram Comparison")
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    # Only build a legend when something was actually plotted.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(fontsize=9, loc='upper right', ncol=2)
    plt.tight_layout()

# plot_ts_compare_stations(
#     ds,
#     station_ids=["QCS01","QCS02", "QCS07", "QCS08", 'HKP01', "HKP03", "HKP04", "HKP05", "HKP06", "FZH01", "FZH13",'DFO2'],
#     years=[2020],
#     depth_range=(100, 450),
#     target_months=[3])

# plot_ts_compare_stations(
#     ds,
#     station_ids=["QCS01","QCS02", "QCS07", "QCS08", 'HKP01', "HKP03", "HKP04", "HKP05", "HKP06", "FZH01", "FZH13",'DFO2'],
#     years=[2020],
#     depth_range=(100, 450),
#     target_months=[4])

# plot_ts_compare_stations(
#     ds,
#     station_ids=["QCS01","QCS02", "QCS07", "QCS08", 'HKP01', "HKP03", "HKP04", "HKP05", "HKP06", "FZH01", "FZH13",'DFO2'],
#     years=[2020],
#     depth_range=(100, 450),
#     target_months=[5])

# plot_ts_compare_stations(
#     ds,
#     station_ids=["QCS01","QCS02", "QCS07", "QCS08", 'HKP01', "HKP03", "HKP04", "HKP05", "HKP06", "FZH01", "FZH13",'DFO2'],
#     years=[2020],
#     depth_range=(100, 450),
#     target_months=[6])

# plot_ts_compare_stations(
#     ds,
#     station_ids=["QCS01","QCS02", "QCS07", "QCS08", 'HKP01', "HKP03", "HKP04", "HKP05", "HKP06", "FZH01", "FZH13",'DFO2'],
#     years=[2020],
#     depth_range=(100, 450),
#     target_months=[7])

# plot_ts_compare_stations(
#     ds,
#     station_ids=["QCS01","QCS02", "QCS07", "QCS08", 'HKP01', "HKP03", "HKP04", "HKP05", "HKP06", "FZH01", "FZH13",'DFO2'],
#     years=[2020],
#     depth_range=(100, 450),
#     target_months=[8])

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import gsw
import itertools

def plot_ts_compare_glider_and_stations(
    ds_station,
    ds_glider,
    station_ids,
    along_values,
    years=[2023],
    depth_range=(100, 450),
    shallowest_depth=None,
    target_months=None,
    xlim=None,
    ylim=None
):
    """
    Compare T–S diagrams across stations and glider transects.

    Station casts are drawn with one tab10 colour per station and a marker
    per cast. Glider data are drawn with one viridis colour per selected
    transect and one marker per along-track position.

    Parameters
    ----------
    ds_station : xarray.Dataset
        Dataset from which ``station``, ``time``, ``depth``, ``temperature``
        and ``salinity`` are read along the ``row`` dimension.
    ds_glider : xarray.Dataset
        Gridded glider data exposing ``transect``, ``depth`` and ``along``.
        ``temperature``/``salinity`` are indexed positionally as
        ``[transect, depth, along]`` (assumed from the indexing below;
        confirm against the file).
    station_ids : sequence of str
        Stations to draw.
    along_values : number or sequence of numbers
        Along-track positions to draw; a single scalar is accepted.
    years : sequence of int
        Calendar years to keep (the default list is only read).
    depth_range : tuple of (float, float)
        Inclusive ``(min, max)`` depth window.
    shallowest_depth : float or None
        When given, glider levels at or above this depth are dropped.
    target_months : sequence of int or None
        Calendar months to keep; ``None`` keeps every month. The same rule
        is applied to stations and glider transects.
    xlim, ylim : tuple or None
        Axis limits; derived from the data (0.1 margin) when ``None``.

    Returns
    -------
    None
        The figure is a side effect; ``plt.rcParams`` is updated globally.
    """
    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 9,
        'figure.titlesize': 20})

    fig, ax = plt.subplots(figsize=(12, 7))

    # Sigma-theta background grid
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 200),
                                 np.linspace(5.3, 16, 200))
    sigma = gsw.sigma0(S_grid, T_grid)

    # Contours
    for levels, color, lw in [
        (np.linspace(24, 27, 7), 'black', 0.5),
        (np.linspace(26, 27, 11), 'black', 0.5),
        ([26], 'black', 2),
        ([25.6], 'yellow', 2),
        ([25.7], 'lime', 2),
        ([25.8], 'red', 2),
        ([25.9], 'blue', 2),
        ([26.6], 'yellow', 2),
        ([26.7], 'lime', 2),
        ([26.8], 'red', 2),
        ([26.9], 'blue', 2)
    ]:
        cs = ax.contour(S_grid, T_grid, sigma, levels=levels,
                        colors=color, linewidths=lw, linestyles='--')
        # Label only the highlighted isopycnals.  The thin families use
        # lw == 0.5, so the previous test against 0.3 was always true and
        # labelled every contour, unlike the other T–S plots in this notebook.
        if lw != 0.5:
            ax.clabel(cs, fmt='%1.2f', fontsize=9, inline=True, rightside_up=True)

    all_S, all_T = [], []

    # Colors for stations and glider.  plt.get_cmap replaces
    # plt.cm.get_cmap, which no longer exists in Matplotlib >= 3.9.
    base_cmap = plt.get_cmap('tab10')
    station_colors = {sid: base_cmap(i % 10) for i, sid in enumerate(station_ids)}
    glider_cmap = plt.get_cmap('viridis')

    # ─── Station data ───
    for station_id in station_ids:
        ds_s = ds_station.where(ds_station['station'] == station_id, drop=True)
        ds_s = ds_s.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time'])
        ds_s = ds_s.where(ds_s['time'].dt.year.isin(years), drop=True)
        if target_months is not None:
            ds_s = ds_s.where(ds_s['time'].dt.month.isin(target_months), drop=True)

        if ds_s.sizes['row'] < 2:
            print(f"No valid station data for {station_id}.")
            continue

        # Work on plain arrays so each cast is a cheap boolean mask.
        time_vals = ds_s['time'].values
        depth_vals = ds_s['depth'].values
        temp_vals = ds_s['temperature'].values
        sal_vals = ds_s['salinity'].values
        usable = ((depth_vals >= depth_range[0]) & (depth_vals <= depth_range[1])
                  & ~np.isnan(temp_vals) & ~np.isnan(sal_vals))

        marker_cycle = itertools.cycle(['o', 'D', 's', 'X', '^', 'v'])

        for t in np.unique(time_vals):
            in_cast = time_vals == t
            if in_cast.sum() < 2:
                continue
            sel = in_cast & usable
            if not sel.any():
                continue

            ax.scatter(
                sal_vals[sel], temp_vals[sel],
                color=station_colors[station_id],
                edgecolors='black',
                s=80,
                marker=next(marker_cycle),
                label=f"{station_id} {pd.Timestamp(t).strftime('%d %b %Y')}", linewidths= 0.1
            )
            all_S.extend(sal_vals[sel])
            all_T.extend(temp_vals[sel])

    # ─── Glider data ───
    # np.ndim also recognises numpy scalars (np.int64, np.float32, ...).
    if np.ndim(along_values) == 0:
        along_values = [along_values]
    cube = ds_glider.where(ds_glider['along'].isin(along_values), drop=True)
    if shallowest_depth is not None:
        cube = cube.where(cube['depth'] > shallowest_depth, drop=True)

    transects = cube.transect.values
    # The first eight characters of each transect id are parsed as a date
    # (assumes ids start with YYYYMMDD — TODO confirm).
    times_glider = [pd.to_datetime(str(t)[:8]) for t in transects]
    selected_idxs = [i for i, t in enumerate(times_glider)
                     if t.year in years and (target_months is None or t.month in target_months)]

    depth_mask = (cube['depth'].values >= depth_range[0]) & (cube['depth'].values <= depth_range[1])
    depth_idxs = np.where(depth_mask)[0]
    along_marker_map = {v: m for v, m in zip(along_values, itertools.cycle(['*', 'h', '+', '|']))}

    for i, idx in enumerate(selected_idxs):
        for j, along_val in enumerate(cube['along'].values):
            if along_val not in along_marker_map:
                continue
            T = cube['temperature'][idx, depth_idxs, j].values.flatten()
            S = cube['salinity'][idx, depth_idxs, j].values.flatten()
            date = times_glider[idx]
            label = f"Glider {date.strftime('%d %b')} @ {int(along_val)}m"
            mask = ~np.isnan(T) & ~np.isnan(S)
            if np.any(mask):
                ax.scatter(
                    S[mask], T[mask],
                    color=glider_cmap(i / len(selected_idxs)),
                    s=50,
                    marker=along_marker_map[along_val],
                    label=label
                )
                all_S.extend(S[mask])
                all_T.extend(T[mask])

    # Auto limits
    if all_S and all_T:
        all_S = np.array(all_S)
        all_T = np.array(all_T)
        if xlim is None:
            xlim = (np.nanmin(all_S) - 0.1, np.nanmax(all_S) + 0.1)
        if ylim is None:
            ylim = (np.nanmin(all_T) - 0.1, np.nanmax(all_T) + 0.1)

    ax.set_xlabel("Salinity (psu)")
    ax.set_ylabel("Temperature (°C)")
    ax.set_title("T–S Comparison: Stations & Glider")
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    # Only build a legend when something was actually plotted.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(fontsize=7, loc='upper right', ncol=2)
    plt.tight_layout()

# Gridded glider data used by the station/glider comparison plots.
cube = xr.open_dataset(
    os.path.expanduser('~/Desktop/Summer 2025 Python/calvert_cube.nc')
)
# plot_ts_compare_glider_and_stations(
#     ds_station=ds,
#     ds_glider=cube,
#     station_ids=['DFO2', 'HKP04', "QCS01", "QCS07"],
#     along_values=[40000],
#     years=[2020],
#     depth_range=(10, 450),
#     target_months=[6, 7, 8, 9])

# QCS stations against four along-track glider positions, June–December 2020.
plot_ts_compare_glider_and_stations(
    ds_station=ds,
    ds_glider=cube,
    station_ids=['QCS01', 'QCS02', 'QCS07', 'QCS08'],
    along_values=[75000, 60000, 45000, 30000],
    years=[2020],
    depth_range=(100, 450),
    target_months=list(range(6, 13)),
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import gsw
import itertools

def plot_ts_colored_by_date(
    ds_station,
    ds_glider,
    station_ids,
    along_values,
    years=[2023],
    depth_range=(100, 450),
    shallowest_depth=None,
    target_months=None,
    xlim=None,
    ylim=None
):
    """
    Plot T–S diagrams with colors assigned by date across all data sources.

    Station casts (filled circles) and glider transects ('x' markers) are
    first collected, then every distinct timestamp that will actually be
    drawn gets its own viridis colour in chronological order. Building the
    colour scale only from plotted dates keeps the full colour range
    available; glider transects outside ``years``/``target_months`` no
    longer take up part of it.

    Parameters
    ----------
    ds_station : xarray.Dataset
        Dataset from which ``station``, ``time``, ``depth``, ``temperature``
        and ``salinity`` are read along the ``row`` dimension.
    ds_glider : xarray.Dataset
        Gridded glider data exposing ``transect``, ``depth`` and ``along``.
        ``temperature``/``salinity`` are indexed positionally as
        ``[transect, depth, along]`` (assumed from the indexing below;
        confirm against the file).
    station_ids : sequence of str
        Stations to draw.
    along_values : number or sequence of numbers
        Along-track glider positions to draw; a scalar is accepted.
    years : sequence of int
        Calendar years to keep (the default list is only read).
    depth_range : tuple of (float, float)
        Inclusive ``(min, max)`` depth window.
    shallowest_depth : float or None
        When given, glider levels at or above this depth are dropped.
    target_months : sequence of int or None
        Calendar months to keep; ``None`` keeps every month.
    xlim, ylim : tuple or None
        Axis limits; derived from the data (0.1 margin) when ``None``.

    Returns
    -------
    None
        The figure is a side effect; ``plt.rcParams`` is updated globally.
    """
    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 8,
        'figure.titlesize': 20
    })

    fig, ax = plt.subplots(figsize=(12, 7))

    # Sigma-theta background grid
    S_grid, T_grid = np.meshgrid(
        np.linspace(28, 35, 300),
        np.linspace(5, 16, 300)
    )
    sigma = gsw.sigma0(S_grid, T_grid)

    # Contours
    cs = ax.contour(S_grid, T_grid, sigma, levels=np.arange(24, 27.5, 0.2),
                    colors='black', linewidths=0.5, linestyles='--')
    ax.clabel(cs, fmt='%1.2f', fontsize=9, inline=True)

    # ─── Collect station series: (timestamp, S, T) ───
    # Each station is filtered once; the result feeds both the colour
    # scale and the plot.
    station_series = []
    for station_id in station_ids:
        ds_s = ds_station.where(ds_station['station'] == station_id, drop=True)
        ds_s = ds_s.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time'])
        ds_s = ds_s.where(ds_s['time'].dt.year.isin(years), drop=True)
        if target_months is not None:
            ds_s = ds_s.where(ds_s['time'].dt.month.isin(target_months), drop=True)

        if ds_s.sizes['row'] < 2:
            continue

        time_vals = ds_s['time'].values
        depth_vals = ds_s['depth'].values
        temp_vals = ds_s['temperature'].values
        sal_vals = ds_s['salinity'].values
        usable = ((depth_vals >= depth_range[0]) & (depth_vals <= depth_range[1])
                  & ~np.isnan(temp_vals) & ~np.isnan(sal_vals))

        for t in np.unique(time_vals):
            in_cast = time_vals == t
            if in_cast.sum() < 2:
                continue
            sel = in_cast & usable
            if sel.any():
                station_series.append((pd.Timestamp(t), sal_vals[sel], temp_vals[sel]))

    # ─── Collect glider series: (date, S, T) ───
    # np.ndim also recognises numpy scalars (np.int64, np.float32, ...).
    if np.ndim(along_values) == 0:
        along_values = [along_values]
    cube = ds_glider.where(ds_glider['along'].isin(along_values), drop=True)
    if shallowest_depth is not None:
        cube = cube.where(cube['depth'] > shallowest_depth, drop=True)

    transects = cube.transect.values
    # The first eight characters of each transect id are parsed as a date
    # (assumes ids start with YYYYMMDD — TODO confirm).
    times_glider = [pd.to_datetime(str(t)[:8]) for t in transects]
    selected_idxs = [i for i, t in enumerate(times_glider)
                     if t.year in years and (target_months is None or t.month in target_months)]

    depth_mask = (cube['depth'].values >= depth_range[0]) & (cube['depth'].values <= depth_range[1])
    depth_idxs = np.where(depth_mask)[0]

    glider_series = []
    for idx in selected_idxs:
        date = times_glider[idx]
        for j in range(cube['along'].size):
            T = cube['temperature'][idx, depth_idxs, j].values.flatten()
            S = cube['salinity'][idx, depth_idxs, j].values.flatten()
            ok = ~np.isnan(T) & ~np.isnan(S)
            if np.any(ok):
                glider_series.append((date, S[ok], T[ok]))

    # ─── Chronological colour per plotted timestamp ───
    sorted_times = sorted({when for when, _, _ in station_series + glider_series})
    cmap = plt.get_cmap('viridis')
    date_colors = {t: cmap(i / max(len(sorted_times) - 1, 1)) for i, t in enumerate(sorted_times)}

    all_S, all_T = [], []

    # ─── Station data ───
    for when, S, T in station_series:
        ax.scatter(
            S, T,
            color=date_colors[when],
            edgecolors='black',
            s=80,
            marker='o',
            label=f"{when.strftime('%d %b %Y')}"
        )
        all_S.extend(S)
        all_T.extend(T)

    # ─── Glider data ───
    # 'x' is an unfilled marker: Matplotlib ignores edgecolors for it (and
    # warns when one is passed), so none is given here.
    for when, S, T in glider_series:
        ax.scatter(
            S, T,
            color=date_colors[when],
            s=40,
            marker='x',
            label=f"{when.strftime('%d %b %Y')}"
        )
        all_S.extend(S)
        all_T.extend(T)

    # Auto limits
    if all_S and all_T:
        all_S = np.array(all_S)
        all_T = np.array(all_T)
        if xlim is None:
            xlim = (np.nanmin(all_S) - 0.1, np.nanmax(all_S) + 0.1)
        if ylim is None:
            ylim = (np.nanmin(all_T) - 0.1, np.nanmax(all_T) + 0.1)

    ax.set_xlabel("Salinity (psu)")
    ax.set_ylabel("Temperature (°C)")
    ax.set_title("T–S Comparison Colored by Date")
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    # Remove duplicate labels (several series can share one calendar date).
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    if by_label:
        ax.legend(by_label.values(), by_label.keys(), fontsize=8, loc='upper right', ncol=2)

    plt.tight_layout()

# All of 2020, 100–450 m window, glider position 40000 along-track.
plot_ts_colored_by_date(
    ds_station=ds,
    ds_glider=cube,
    station_ids=['DFO2', 'HKP04', 'QCS01', 'QCS07'],
    along_values=[40000],
    years=[2020],
    depth_range=(100, 450),
    target_months=list(range(1, 13)),
)

In [None]:
import gsw
# Two (salinity, temperature) end-members to be mixed.
S1, T1 = 33.6892, 6.722
S2, T2 = 33.2209, 7.394

# Linear mix: one third of end-member 1, two thirds of end-member 2.
S_mix = (S1/3 + S2*(2/3))
T_mix = (T1/3 + T2*(2/3))

# sigma0 of each end-member and of the mixture.  The variables QCS01 / DFO2
# hold density anomalies here, not station data.
# NOTE(review): gsw.sigma0 is documented as taking Absolute Salinity and
# Conservative Temperature — confirm the values above are in those units.
QCS01 = gsw.sigma0(S1, T1)
DFO2 = gsw.sigma0(S2, T2)
sigma_mix = gsw.sigma0(S_mix, T_mix)

print(f"Sigma1 = {QCS01:.2f}")
print(f"Sigma2 = {DFO2:.2f}")
print(f"Sigma_mix = {sigma_mix:.2f}")

In [None]:
def plot_ts_from_station_by_depth(ds, station_id, years=[2023, 2024, 2025],
                                   depth_range=(100, 450), xlim=None, ylim=None, target_months=None):
    """
    Draw a T–S diagram for one station with points coloured by depth.

    Each cast (unique timestamp) gets its own marker shape. All casts share
    a single depth colour scale, so a point's colour can be read directly
    off the colour bar.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset from which ``station``, ``time``, ``depth``, ``temperature``
        and ``salinity`` are read along the ``row`` dimension.
    station_id : str
        Station to plot.
    years : sequence of int
        Calendar years to keep (the default list is only read).
    depth_range : tuple of (float, float)
        Inclusive ``(min, max)`` depth window.
    xlim, ylim : tuple or None
        Axis limits; derived from the data (0.1 margin) when ``None``.
    target_months : sequence of int or None
        Calendar months to keep; ``None`` keeps every month.

    Returns
    -------
    None
        The figure is a side effect; ``plt.rcParams`` is updated globally.
        A message is printed and no figure is created when no cast has
        usable data.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import gsw
    import itertools
    from matplotlib.cm import ScalarMappable
    from matplotlib.colors import Normalize
    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 20,
        'figure.titlesize': 20})

    # ─── Filter and clean ─────────────────────────────
    ds_station = ds.where(ds['station'] == station_id, drop=True)
    ds_station = ds_station.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time'])
    ds_station = ds_station.where(ds_station['time'].dt.year.isin(years), drop=True)
    if target_months is not None:
        ds_station = ds_station.where(ds_station['time'].dt.month.isin(target_months), drop=True)

    if ds_station.sizes['row'] < 2:
        print(f"No valid data for station {station_id} in years {years}.")
        return

    # ─── Mask depth range ─────────────────────────────
    mask = (ds_station['depth'] >= depth_range[0]) & (ds_station['depth'] <= depth_range[1])
    ds_station = ds_station.where(mask, drop=True)

    # ─── Extract arrays ───────────────────────────────
    time_vals = ds_station['time'].values
    sal_vals = ds_station['salinity'].values
    temp_vals = ds_station['temperature'].values
    depth_vals = ds_station['depth'].values
    finite = ~np.isnan(temp_vals) & ~np.isnan(sal_vals)

    # ─── Collect casts before plotting ────────────────
    # The shared colour scale needs the overall depth extent, so the casts
    # are gathered first and drawn afterwards.
    marker_cycle = itertools.cycle(['o', 'D', 's', 'X', 'P', '^', 'v', '<', '>', '*', 'h', '+'])
    casts = []
    for t in np.unique(time_vals):
        # The marker advances for every timestamp, including skipped ones.
        marker = next(marker_cycle)
        in_cast = time_vals == t
        if in_cast.sum() < 2:
            continue
        sel = in_cast & finite
        if not sel.any():
            continue
        casts.append((t, marker, sal_vals[sel], temp_vals[sel], depth_vals[sel]))

    if not casts:
        print(f"No valid data for station {station_id} in years {years} "
              f"within depth range {depth_range}.")
        return

    all_S = np.concatenate([c[2] for c in casts])
    all_T = np.concatenate([c[3] for c in casts])
    all_depth = np.concatenate([c[4] for c in casts])

    # One normalisation for every scatter call AND the colour bar; without
    # it each cast would be auto-scaled to its own depth range and the
    # colours would not be comparable between casts.
    depth_norm = Normalize(vmin=np.nanmin(all_depth), vmax=np.nanmax(all_depth))

    # ─── Plot ─────────────────────────────────────────
    fig, ax = plt.subplots(figsize=(1.5 * 1.5 * 6.4, 1.5 * 4.8))

    # Sigma-theta contours
    S_grid, T_grid = np.meshgrid(np.linspace(28, 35, 200), np.linspace(5.3, 20, 200))
    sigma = gsw.sigma0(S_grid, T_grid)

    for levels, color, lw in [
        (np.linspace(24, 27, 7), 'black', 0.5),
        (np.linspace(26, 27, 11), 'black', 0.5),
        ([26], 'black', 2),
        ([25.6], 'yellow', 2), ([25.7], 'lime', 2), ([25.8], 'red', 2), ([25.9], 'blue', 2),
        ([26.6], 'yellow', 2), ([26.7], 'lime', 2), ([26.8], 'red', 2), ([26.9], 'blue', 2)
    ]:
        cs = ax.contour(S_grid, T_grid, sigma, levels=levels,
                        colors=color, linewidths=lw, linestyles='--')
        # Only the highlighted (thick) isopycnals get an inline label.
        if lw != 0.5:
            ax.clabel(cs, fmt='%1.2f', fontsize=9, inline=True)

    for t, marker, sal, temp, depth in casts:
        ax.scatter(sal, temp,
                   c=depth,
                   cmap='viridis',
                   norm=depth_norm,
                   marker=marker,
                   edgecolors='black',
                   s=60,
                   linewidth=0.2,
                   label=pd.Timestamp(t).strftime('%d %b %Y'))

    # Axis limits
    if xlim is None:
        xlim = (np.nanmin(all_S) - 0.1, np.nanmax(all_S) + 0.1)
    if ylim is None:
        ylim = (np.nanmin(all_T) - 0.1, np.nanmax(all_T) + 0.1)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    # Labels
    ax.set_xlabel("Salinity (psu)")
    ax.set_ylabel(r"$\theta$ (°C)")
    ax.set_title(f"T–S Diagram at {station_id} (colored by depth)")

    # Colorbar built from the same normalisation as the points.
    sm = ScalarMappable(norm=depth_norm, cmap='viridis')
    sm.set_array([])
    cbar = plt.colorbar(sm, ax=ax)
    cbar.set_label("Depth (m)")

    ax.legend(fontsize=8, loc='upper right')
    plt.tight_layout()

# One depth-coloured T–S figure per (station, years) pair, 10–450 m window.
for _sid, _yrs in [
    ('HKP04', [2020]),
    ('HKP04', [2021, 2022]),
    ('HKP04', [2022]),
    ('QCS01', [2020]),
    ('QCS07', [2020]),
    ('HKP03', [2020]),
]:
    plot_ts_from_station_by_depth(ds, _sid, years=_yrs, depth_range=(10, 450))

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import gsw

# ─── Model set-up ───
# Vertical grid: nz evenly spaced levels spanning 0–400 m.
nz = 100
z = np.linspace(0, 400, nz)
dz = z[1] - z[0]

# Starting profiles vary linearly with depth:
# 6.4 → 7.5 °C and 32.9 → 33.5 psu from the first level to the last.
T = np.linspace(6.4, 7.5, nz)
S = np.linspace(32.9, 33.5, nz)

# Vertical diffusivity (m²/s)
Kz = 1e-4

# Time stepping: 6-hour step, 180 days in total, snapshot every 30 days.
dt = 6 * 3600
n_steps = (180 * 24 * 3600) // dt
save_interval = (30 * 24 * 3600) // dt

# Snapshot storage, seeded with the initial state.
saved_labels = ["Start"]
saved_T = [T.copy()]
saved_S = [S.copy()]

# ─── Diffusion Step Function ───
def diffuse(profile, Kz, dz, dt):
    """
    Advance a 1-D profile by one explicit diffusion step.

    Interior points are incremented by ``dt * Kz`` times the centred second
    difference divided by ``dz**2``. The first and last points are copied
    through unchanged, and the input array is not modified.

    Parameters
    ----------
    profile : numpy.ndarray
        Values on an evenly spaced grid.
    Kz : float
        Diffusivity.
    dz : float
        Grid spacing.
    dt : float
        Time step.

    Returns
    -------
    numpy.ndarray
        New array of the same shape holding the stepped profile.
    """
    increment = dt * Kz * (profile[2:] - 2 * profile[1:-1] + profile[:-2]) / dz**2
    stepped = profile.copy()
    stepped[1:-1] += increment
    return stepped

# ─── Integrate forward in time ───
# NOTE(review): T and S start as straight lines (np.linspace) and diffuse()
# leaves both end points untouched, so the second difference is ~0 at every
# step and the saved profiles should barely move from "Start" — confirm this
# is the intended experiment.
for step in range(1, n_steps + 1):
    T, S = diffuse(T, Kz, dz, dt), diffuse(S, Kz, dz, dt)

    if step % save_interval != 0:
        continue
    saved_T.append(T.copy())
    saved_S.append(S.copy())
    saved_labels.append(f"Day {int(step * dt / 86400)}")

# ─── T–S trajectories of a few fixed depth levels ───
fig, ax = plt.subplots(figsize=(8, 6))

# Isopycnal background
S_grid, T_grid = np.meshgrid(np.linspace(32, 34, 200), np.linspace(5, 9, 200))
sigma = gsw.sigma0(S_grid, T_grid)
cs = ax.contour(S_grid, T_grid, sigma,
                levels=np.arange(24, 28, 0.2),
                colors='gray', linewidths=0.5)
ax.clabel(cs, inline=True, fontsize=8, fmt='%1.1f')

# Grid indices to follow through the saved snapshots, one colour each.
depth_indices = [10, 30, 50, 70, 90]
colors = plt.cm.viridis(np.linspace(0, 1, len(depth_indices)))

for idx, line_color in zip(depth_indices, colors):
    s_traj = [snapshot[idx] for snapshot in saved_S]
    t_traj = [snapshot[idx] for snapshot in saved_T]
    ax.plot(s_traj, t_traj, '-o', color=line_color, label=f"{int(z[idx])} m")

ax.set_xlabel("Salinity (psu)")
ax.set_ylabel("Temperature (°C)")
ax.set_title("T–S Evolution at Select Depths (Diapycnal Mixing)")
ax.legend()
plt.tight_layout()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cmocean.cm as cm
import gsw
import pandas as pd
import xarray as xr

def plot_density_gradient(ds, station_id, years=[2023, 2024, 2025], depth_range=(0, 420)):
    """
    Plot a time–depth section of the vertical sigma-theta gradient.

    Each cast is sorted by depth, interpolated onto a 1 m grid spanning
    ``depth_range``, converted to ``gsw.sigma0`` and differentiated with
    respect to depth. Grid points outside a cast's sampled depths are NaN.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset from which ``station``, ``time``, ``depth``, ``temperature``
        and ``salinity`` are read along the ``row`` dimension.
    station_id : str
        Station to plot.
    years : sequence of int
        Calendar years to keep (the default list is only read).
    depth_range : tuple of (float, float)
        ``(min, max)`` extent of the 1 m interpolation grid.

    Returns
    -------
    None
        The figure is a side effect; ``plt.rcParams`` is updated globally.
        A message is printed when there is nothing to plot.
    """
    # Imported here so the colormap lookup does not depend on which of the
    # notebook's `cm` aliases (cmocean vs cmocean.cm) was bound last.
    import cmocean

    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 20,
        'figure.titlesize': 20})

    # Filter to station and clean
    ds_station = ds.where(ds['station'] == station_id, drop=True)
    ds_station = ds_station.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time'])
    ds_station = ds_station.where(ds_station['time'].dt.year.isin(years), drop=True)

    if ds_station.sizes['row'] < 2:
        print(f"No valid data for station {station_id}.")
        return

    time_vals = ds_station['time'].values
    depth_vals = ds_station['depth'].values
    temp_vals = ds_station['temperature'].values
    sal_vals = ds_station['salinity'].values

    # Interpolation grid (1 m spacing)
    z = np.arange(depth_range[0], depth_range[1] + 1, 1)
    drho_dz_all = []
    valid_times = []

    for t in np.unique(time_vals):
        in_cast = time_vals == t
        # At least three samples are required for a cast to be used.
        if in_cast.sum() < 3:
            continue

        # np.interp requires increasing sample positions; the rows of a
        # cast are not guaranteed to be stored in depth order.
        order = np.argsort(depth_vals[in_cast], kind='stable')
        depth = depth_vals[in_cast][order]
        temp = temp_vals[in_cast][order]
        sal = sal_vals[in_cast][order]

        # Interpolate T/S; outside the sampled depths the result is NaN.
        temp_i = np.interp(z, depth, temp, left=np.nan, right=np.nan)
        sal_i = np.interp(z, depth, sal, left=np.nan, right=np.nan)

        # Compute sigma0
        sigma_i = gsw.sigma0(sal_i, temp_i)

        # Compute vertical gradient d(sigma)/dz
        drho_dz_all.append(np.gradient(sigma_i, z))
        valid_times.append(t)

    if not drho_dz_all:
        print("No valid profiles after interpolation.")
        return

    # Convert to array
    drho_dz_array = np.array(drho_dz_all).T  # shape: (depth, time)
    time_axis = pd.to_datetime(valid_times)

    # Plot
    fig, ax = plt.subplots(figsize=(12, 6))
    pc = ax.pcolormesh(time_axis, z, drho_dz_array, cmap=cmocean.cm.balance_r,
                       shading='auto', vmin=-0.02, vmax=0.02)
    ax.invert_yaxis()
    ax.set_title(f"dρ/dz at {station_id}")
    ax.set_xlabel("Time")
    ax.set_ylabel("Depth (m)")
    cbar = plt.colorbar(pc, ax=ax, label=r"$\partial \sigma_\theta / \partial z$ (kg m$^{-4}$)")

    plt.tight_layout()

# Density-gradient section at HKP04 for 2023–2025.
plot_density_gradient(
    ds,
    station_id='HKP04',
    years=[2023, 2024, 2025],
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import gsw
import pandas as pd

def plot_density_and_gradient_profiles(ds, station_id, years=[2023, 2024, 2025], depth_range=(0, 420),
                                       n_profiles=6, smooth_window_m=10):
    """
    Plot sigma-theta and its smoothed vertical gradient for selected casts.

    Up to ``n_profiles`` casts, evenly spread over the available
    timestamps, each get their own figure: sigma-theta on the bottom axis
    and the smoothed gradient on a twin top axis.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset from which ``station``, ``time``, ``depth``, ``temperature``
        and ``salinity`` are read along the ``row`` dimension.
    station_id : str
        Station to plot.
    years : sequence of int
        Calendar years to keep (the default list is only read).
    depth_range : tuple of (float, float)
        Inclusive ``(min, max)`` depth window; samples outside it are
        dropped before density and gradient are computed.
    n_profiles : int
        Maximum number of casts to draw.
    smooth_window_m : float
        Width in metres of the centred running mean applied to the
        gradient after it has been interpolated to a 1 m grid.

    Returns
    -------
    None
        Figures are created as a side effect. A message is printed when
        the station has no casts.
    """
    ds_station = ds.where(ds['station'] == station_id, drop=True)
    ds_station = ds_station.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time'])
    ds_station = ds_station.where(ds_station['time'].dt.year.isin(years), drop=True)

    time_vals = ds_station['time'].values
    depth_vals = ds_station['depth'].values
    temp_vals = ds_station['temperature'].values
    sal_vals = ds_station['salinity'].values

    times = np.unique(time_vals)
    if len(times) < 1:
        print("No profiles found.")
        return

    # Evenly spaced picks across the sorted timestamps.
    selected_times = np.linspace(0, len(times)-1, min(n_profiles, len(times)), dtype=int)

    # On the 1 m grid the window length in samples equals its width in metres.
    window = max(int(round(smooth_window_m)), 1)

    for i in selected_times:
        t = times[i]
        in_cast = time_vals == t
        if in_cast.sum() < 3:
            continue

        depth = depth_vals[in_cast]
        temp = temp_vals[in_cast]
        sal = sal_vals[in_cast]

        # Keep finite samples inside the requested depth window.
        keep = (~np.isnan(depth) & ~np.isnan(temp) & ~np.isnan(sal)
                & (depth >= depth_range[0]) & (depth <= depth_range[1]))

        # Sort by depth and average samples that share a depth; repeated
        # depths would make np.gradient divide by zero.
        cast = (pd.DataFrame({'depth': depth[keep], 'temp': temp[keep], 'sal': sal[keep]})
                .groupby('depth', sort=True)
                .mean())
        if len(cast) < 3:
            continue

        depth = cast.index.to_numpy()
        temp = cast['temp'].to_numpy()
        sal = cast['sal'].to_numpy()

        sigma = gsw.sigma0(sal, temp)
        drho_dz = np.gradient(sigma, depth)

        # Running mean over `smooth_window_m` metres: interpolate the
        # gradient to a uniform 1 m grid first so the window is a fixed
        # physical width whatever the original sample spacing.
        depth_smooth = np.arange(depth[0], depth[-1] + 1e-9, 1.0)
        drho_dz_uniform = np.interp(depth_smooth, depth, drho_dz)
        drho_dz_smooth = (pd.Series(drho_dz_uniform)
                          .rolling(window=window, min_periods=1, center=True)
                          .mean()
                          .to_numpy())

        # Plot
        fig, ax1 = plt.subplots(figsize=(5, 6))
        ax1.plot(sigma, depth, color='black', label=r"$\sigma_\theta$")
        ax1.set_xlabel(r"$\sigma_\theta$ (kg m$^{-3}$)")
        ax1.set_ylabel("Depth (m)")
        ax1.invert_yaxis()
        ax1.set_title(f"{station_id} — {pd.Timestamp(t).strftime('%d %b %Y')}")

        ax2 = ax1.twiny()
        ax2.plot(drho_dz_smooth, depth_smooth, color='blue', linestyle='-', label=r"$\partial \sigma_\theta / \partial z$")
        # Zero-gradient reference line.
        ax2.axvline(0, color='black', linestyle='--')
        ax2.set_xlabel(r"$\partial \sigma_\theta / \partial z$ (kg m$^{-4}$)")
        ax2.spines['top'].set_color('blue')
        ax2.tick_params(axis='x', colors='blue')
        ax2.set_xlim(-0.1, 0.1)

        # Legends (handles from both axes merged into one box)
        lines1, labels1 = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax1.legend(lines1 + lines2, labels1 + labels2, loc='lower right')

        plt.tight_layout()

# Five evenly spaced 2023 casts at HKP04.
plot_density_and_gradient_profiles(
    ds,
    'HKP04',
    years=[2023],
    n_profiles=5,
)

In [None]:
def plot_stations_on_bathymetry(
    station_list,
    ds,
    topo_file=os.path.expanduser('~/Desktop/Summer 2025 Python/british_columbia_3_msl_2013.nc'),
    lat_buffer=1.2, lon_buffer=1.4,
    lon_bounds=None, lat_bounds=None,
    deepest=500,
    title="Bathymetry Map with Selected Stations"):
    """Draw a filled bathymetry map and mark the requested stations on it.

    Station positions are read from ``ds`` (rows missing ``station``,
    ``latitude`` or ``longitude`` are dropped) and de-duplicated per
    (station, latitude, longitude) combination. Bathymetry comes from the
    ``Band1`` variable of ``topo_file``; it is negated before contouring so
    that depth is positive.

    Parameters
    ----------
    station_list : sequence of str, or str
        Station names to mark. A single name may be given as a plain string.
    ds : xarray.Dataset
        Dataset holding ``station``, ``latitude`` and ``longitude`` along ``row``.
    topo_file : str
        Path of the bathymetry NetCDF file (coordinates ``lon``/``lat``).
    lat_buffer, lon_buffer : float
        Padding in degrees added around the station extent on an axis whose
        explicit bounds are not supplied.
    lon_bounds, lat_bounds : tuple of float, optional
        Explicit ``(min, max)`` limits. Each axis is handled independently, so
        supplying only one of them no longer discards it.
    deepest : float
        Deepest contour level; 22 levels are spread evenly over 0..deepest.
    title : str
        Axes title.

    Returns
    -------
    None
        A message is printed and the function returns early when no station
        matches or when the cropped bathymetry is empty.
    """
    import xarray as xr
    import matplotlib.pyplot as plt
    import cartopy.crs as ccrs
    import numpy as np
    import cmocean as cm

    # NOTE: this changes matplotlib's global defaults for the whole session.
    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 20,
        'figure.titlesize': 20})

    # A bare string would otherwise be matched character-by-character / by substring.
    if isinstance(station_list, str):
        station_list = [station_list]
    wanted = set(station_list)

    # Keep only rows that carry a station name and a position.
    ds_valid = ds.dropna(dim='row', subset=['station', 'latitude', 'longitude'])

    # First occurrence of every (station, lat, lon) combination, in dataset order.
    seen = set()
    station_coords = []
    for name, lat, lon in zip(ds_valid['station'].values,
                              ds_valid['latitude'].values,
                              ds_valid['longitude'].values):
        if name not in wanted:
            continue
        key = (name, lat, lon)
        if key in seen:
            continue
        seen.add(key)
        station_coords.append({'station': name, 'latitude': lat, 'longitude': lon})

    if not station_coords:
        print("No matching stations found.")
        return

    # Map bounds: explicit limits win per axis, otherwise pad the station extent.
    lats = [row['latitude'] for row in station_coords]
    lons = [row['longitude'] for row in station_coords]

    if lat_bounds is None:
        min_lat, max_lat = min(lats) - lat_buffer, max(lats) + lat_buffer
    else:
        min_lat, max_lat = lat_bounds

    if lon_bounds is None:
        min_lon, max_lon = min(lons) - lon_buffer, max(lons) + lon_buffer
    else:
        min_lon, max_lon = lon_bounds

    # Load only the cropped window into memory, then release the file handle.
    with xr.open_dataset(topo_file) as topo_full:
        topo = topo_full.sel(lon=slice(min_lon, max_lon),
                             lat=slice(min_lat, max_lat)).load()
    depth = -topo['Band1']  # Band1 is negative below sea level; plot positive depth

    if depth.size == 0:
        print("No bathymetry data inside the requested bounds.")
        return

    # Plot
    fig, ax = plt.subplots(figsize=(12, 9), subplot_kw={'projection': ccrs.PlateCarree()})
    ax.set_extent([min_lon, max_lon, min_lat, max_lat], crs=ccrs.PlateCarree())

    levels = np.linspace(0, deepest, 22)
    cf = ax.contourf(topo['lon'], topo['lat'], depth, levels=levels, cmap=cm.cm.deep, extend='both')
    cs = ax.contour(topo['lon'], topo['lat'], depth, levels=levels, colors='k', linewidths=0.3)
    ax.clabel(cs, fmt='%d', fontsize=4)
    plt.colorbar(cf, ax=ax, label='Depth (m)')

    # Station markers with the name just to the right of each marker.
    for row in station_coords:
        ax.plot(row['longitude'], row['latitude'], marker='o', color='red', markersize=10, transform=ccrs.PlateCarree())
        ax.text(row['longitude'] + 0.01, row['latitude'], row['station'], fontsize=10, transform=ccrs.PlateCarree())

    # Gridlines and labels (left/bottom only)
    gl = ax.gridlines(draw_labels=True, linestyle='--', alpha=0.3)
    gl.top_labels = False
    gl.right_labels = False

    ax.set_title(title)
    ax.set_xlabel("Longitude")
    ax.set_ylabel("Latitude")
    # Stretch latitude by 1/cos(mid-latitude) so distances look roughly isotropic.
    ax.set_aspect(1 / np.cos(np.deg2rad((min_lat + max_lat) / 2)))

# Stations to mark on the map (names must match the dataset's 'station' values).
stations_to_plot = ['QCS07',"FZH01","FZH13", "FZH08" , "FZH14", 'HKP06', 'HKP04', 'DFO2', "KC10", "HAK1", "DE1"]

# Explicit lon/lat bounds are given, so the buffer-based extent is not used.
# NOTE(review): `ds_stations` is not defined in this cell; presumably it is an
# xarray Dataset created in an earlier cell (the function calls
# ds.dropna(dim='row', ...)). Verify it exists on a fresh kernel.
plot_stations_on_bathymetry(
    stations_to_plot,
    ds_stations,
    lon_bounds=(-128.5, -127.2),
    lat_bounds=(51.3, 52.3))


In [None]:
# # Define bounding box for the region of interest (adjust as needed)
# lat_min, lat_max = 51.405, 51.406
# lon_min, lon_max = -127.88, -127.87

# # Crop the bathymetry
# region = topo.sel(lat=slice(lat_min, lat_max), lon=slice(lon_min, lon_max))
# depth_region = -region['Band1']  # convert to positive depth

# # Find deepest point
# deepest_depth = depth_region.max().item()
# deepest_location = np.unravel_index(np.nanargmax(depth_region.values), depth_region.shape)
# deepest_lat = region['lat'].values[deepest_location[0]]
# deepest_lon = region['lon'].values[deepest_location[1]]

# print(f"Deepest depth: {deepest_depth:.2f} m at lat={deepest_lat:.5f}, lon={deepest_lon:.5f}")

In [None]:
# Find the shallowest bathymetry point inside a small lat/lon box.
# NOTE(review): `topo` is not defined in this cell; the next cell binds it with
# xr.open_dataset(topo_file). On a fresh kernel this cell needs `topo` to have
# been loaded already (possibly by an earlier cell not visible here).

# Define region of interest
lat_min, lat_max = 51.31, 51.34
lon_min, lon_max = -127.89, -127.87

# Extract region and convert to positive depth (bathymetry is negative)
region = topo.sel(lat=slice(lat_min, lat_max), lon=slice(lon_min, lon_max))
depth_region = -region['Band1']

# Check for valid data (an all-NaN or empty selection takes the first branch)
if np.isnan(depth_region).all():
    print("No valid depth data in this region.")
else:
    # Smallest positive depth; NaN cells are ignored by both min() and nanargmin.
    shallowest_depth = depth_region.min().item()
    # Flat argmin -> (row, col); indexing below assumes Band1 dims are (lat, lon) — TODO confirm.
    idx = np.unravel_index(np.nanargmin(depth_region.values), depth_region.shape)
    shallowest_lat = region['lat'].values[idx[0]]
    shallowest_lon = region['lon'].values[idx[1]]

    print(f"Shallowest depth: {shallowest_depth:.2f} m at lat={shallowest_lat:.5f}, lon={shallowest_lon:.5f}")

In [None]:
import xarray as xr
import numpy as np
import os

# Load the bathymetry dataset (binds the notebook-level `topo` used by other cells)
topo_file = os.path.expanduser('~/Desktop/Summer 2025 Python/british_columbia_3_msl_2013.nc')
topo = xr.open_dataset(topo_file)

# Coordinates of QCS07 (hard-coded here rather than read from the station dataset)
lat_qcs07 = 51.3815
lon_qcs07 = -127.9063

# Interpolate depth at that location (Band1 is negative depth), negated to a positive depth
depth_qcs07 = -topo['Band1'].interp(lat=lat_qcs07, lon=lon_qcs07).item()

print(f"Interpolated depth at QCS07 ({lat_qcs07}, {lon_qcs07}): {depth_qcs07:.2f} m")

In [None]:
def plot_oxygen_timeseries(station_id, ds, ylim=(450, 0), years=[2023, 2024, 2025], show_sill=True, oxymax=200):
    """Plot a depth-time section of dissolved oxygen with density contours for one station.

    Every distinct ``time`` value at the station is treated as one cast. Each
    cast is sorted by depth and linearly interpolated onto a 1 m grid
    (0-499 m); values outside the sampled depth range are NaN. Oxygen is drawn
    with a diverging colour map centred on 60 (the hypoxic threshold, also
    drawn as a red contour) and sigma-0 contours are overlaid.

    NOTE(review): a later cell in this notebook defines another
    ``plot_oxygen_timeseries``; once that cell runs, it replaces this one.

    Parameters
    ----------
    station_id : str
        Value of ``ds['station']`` to select.
    ds : xarray.Dataset
        Needs ``station``, ``time``, ``depth``, ``temperature``, ``salinity``
        and ``dissolved_oxygen_ml_l`` along ``row``.
    ylim : tuple
        Depth-axis limits, passed to ``set_ylim`` (default puts 0 at the top).
    years : sequence of int
        Calendar years to keep.
    show_sill : bool
        Draw dashed reference lines at 120 and 130 on the depth axis.
    oxymax : float or None
        Upper colour limit; ``None`` uses the data maximum. Must exceed 60,
        otherwise ``TwoSlopeNorm`` raises ``ValueError``.

    Returns
    -------
    None
        Prints a message and returns early when there is too little data.
    """
    import matplotlib.pyplot as plt
    from matplotlib.dates import DateFormatter
    import numpy as np
    import cmocean as cm
    import gsw
    import pandas as pd
    from matplotlib.colors import TwoSlopeNorm

    # NOTE: this changes matplotlib's global defaults for the whole session.
    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 20,
        'figure.titlesize': 20})

    def _day_label(stamp):
        # '%-d' is a platform-specific strftime extension; build the label portably.
        ts = pd.Timestamp(stamp)
        return f"{ts.strftime('%b')} {ts.day}"

    # ─── Filter and clean ─────────────────────────────
    ds_station = ds.where(ds['station'] == station_id, drop=True)
    ds_station = ds_station.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time', 'dissolved_oxygen_ml_l'])
    ds_station = ds_station.where(ds_station['time'].dt.year.isin(years), drop=True)

    if ds_station.sizes['row'] < 2:
        print(f"No valid data for station {station_id} in years {years}.")
        return

    # ─── Pull plain arrays once (avoids re-filtering the dataset per cast) ─
    depth_all = ds_station['depth'].values
    time_all = ds_station['time'].values
    # NOTE(review): gsw.sigma0 is documented for Absolute Salinity / Conservative
    # Temperature; the dataset's salinity/temperature are passed as-is — confirm.
    sigma_all = np.asarray(gsw.sigma0(ds_station['salinity'].values,
                                      ds_station['temperature'].values))
    # NOTE(review): 44.661 looks like the mL/L -> µmol/L factor while the colour
    # bar is labelled µmol/kg (no density division here) — confirm the units.
    oxygen_all = ds_station['dissolved_oxygen_ml_l'].values * 44.661

    # ─── Group by time, interpolate ───────────────────
    regular_depth = np.arange(0, 500, 1)
    oxy_profiles, dens_profiles, valid_times = [], [], []

    for t in np.unique(time_all):
        in_cast = time_all == t
        if np.count_nonzero(in_cast) < 2:
            continue
        # np.interp needs increasing x and does not check it: sort each cast by depth.
        order = np.argsort(depth_all[in_cast], kind='stable')
        cast_depth = depth_all[in_cast][order]
        oxy_profiles.append(np.interp(regular_depth, cast_depth, oxygen_all[in_cast][order],
                                      left=np.nan, right=np.nan))
        dens_profiles.append(np.interp(regular_depth, cast_depth, sigma_all[in_cast][order],
                                       left=np.nan, right=np.nan))
        valid_times.append(t)

    if not oxy_profiles:
        print(f"No usable profiles at station {station_id}")
        return

    oxy = np.array(oxy_profiles).T          # shape (depth, time)
    density = np.array(dens_profiles).T
    times = np.array(valid_times).astype('datetime64[ns]')
    oxymax = oxymax if oxymax is not None else np.nanmax(oxy)

    # ─── Plot ──────────────────────────────────────────
    fig, ax = plt.subplots(figsize=(1.5 * 1.5 * 6.4, 1.5 * 4.8))
    # Single norm from oxymax. The former throw-away norm built from the data
    # maximum raised ValueError whenever all oxygen values were <= 60.
    norm = TwoSlopeNorm(vmin=0, vcenter=60, vmax=oxymax)
    pc = ax.pcolormesh(times, regular_depth, oxy, cmap=cm.cm.balance_r, norm=norm, shading='nearest')

    # ─── Density contours ─────────────────────────────
    # (levels, colour, line width); only entries whose width is not 0.5 get labels.
    contour_settings = [
        (np.linspace(24, 27, 7), 'black', .51),
        ([25.6], 'white', 0.5),
        ([25.7], 'lime', 0.5),
        ([25.8], 'red', 0.5),
        ([25.9], 'blue', 0.5),
        ([26.0], 'black', 0.51),
        ([26.1], 'purple', 0.5),
        ([26.2], 'salmon', 0.5),
        ([26.3], 'yellow', 0.5),
        ([26.4], 'cyan', 0.5)]

    for levels, color, lw in contour_settings:
        cf = ax.contour(times, regular_depth, density, levels=levels, colors=color, linewidths=lw)
        if lw != 0.5:
            ax.clabel(cf, fmt='%1.2f', fontsize=8)

    # ─── Hypoxic threshold ────────────────────────────
    cs = ax.contour(times, regular_depth, oxy, levels=[60], colors='red', linewidths=2)
    ax.clabel(cs, fmt='%1.1f')

    if show_sill:
        ax.plot(times, [130] * len(times), color='black', linestyle='--', linewidth=1.5)
        ax.plot(times, [120] * len(times), color='black', linestyle='--', linewidth=1.5)

    ax.set_title(f"Oxygen Time Series at {station_id}")
    ax.set_xlabel("Date")
    ax.set_ylabel("Depth (m)")
    ax.invert_yaxis()
    ax.set_ylim(ylim)
    fig.colorbar(pc, ax=ax, label="Oxygen (μmol/kg)")
    ax.xaxis.set_major_formatter(DateFormatter('%b %d\n%Y'))

    # ─── Secondary top axis: one tick per cast ─────────
    ax_top = ax.secondary_xaxis('top')
    ax_top.set_xticks(times)
    ax_top.set_xticklabels([_day_label(t) for t in times], rotation=90, fontsize=8)
    ax_top.tick_params(axis='x', direction='out', length=3, width=0.5, pad=2)

    plt.tight_layout()

In [None]:
def plot_oxygen_timeseries(station_id, ds, ylim=(450, 0), years=[2023, 2024, 2025],
                           show_sill=True, oxymax=200, show_density_contours=True):
    """Plot a depth-time section of dissolved oxygen for one station.

    Every distinct ``time`` value at the station is treated as one cast. Each
    cast is sorted by depth and linearly interpolated onto a 1 m grid
    (0-499 m); values outside the sampled depth range are NaN. Oxygen is drawn
    with a diverging colour map centred on 60.

    Parameters
    ----------
    station_id : str
        Value of ``ds['station']`` to select.
    ds : xarray.Dataset
        Needs ``station``, ``time``, ``depth``, ``temperature``, ``salinity``
        and ``dissolved_oxygen_ml_l`` along ``row``.
    ylim : tuple
        Depth-axis limits, passed to ``set_ylim`` (default puts 0 at the top).
    years : sequence of int
        Calendar years to keep.
    show_sill : bool
        Draw dashed reference lines at 120 and 130 on the depth axis.
    oxymax : float or None
        Upper colour limit; ``None`` uses the data maximum. Must exceed 60,
        otherwise ``TwoSlopeNorm`` raises ``ValueError``.
    show_density_contours : bool
        ``True`` overlays sigma-0 contours. ``False`` overlays oxygen contours
        every 10 units in black plus the 60 contour in red.

    Returns
    -------
    None
        Prints a message and returns early when there is too little data.
    """
    import matplotlib.pyplot as plt
    from matplotlib.dates import DateFormatter
    import numpy as np
    import cmocean as cm
    import gsw
    import pandas as pd
    from matplotlib.colors import TwoSlopeNorm

    # NOTE: this changes matplotlib's global defaults for the whole session.
    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 20,
        'figure.titlesize': 20})

    def _day_label(stamp):
        # '%-d' is a platform-specific strftime extension; build the label portably.
        ts = pd.Timestamp(stamp)
        return f"{ts.strftime('%b')} {ts.day}"

    # Select the station, drop incomplete rows, keep the requested years.
    ds_station = ds.where(ds['station'] == station_id, drop=True)
    ds_station = ds_station.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time', 'dissolved_oxygen_ml_l'])
    ds_station = ds_station.where(ds_station['time'].dt.year.isin(years), drop=True)

    if ds_station.sizes['row'] < 2:
        print(f"No valid data for station {station_id} in years {years}.")
        return

    # Pull plain arrays once instead of re-filtering the dataset for every cast.
    depth_all = ds_station['depth'].values
    time_all = ds_station['time'].values
    # NOTE(review): gsw.sigma0 is documented for Absolute Salinity / Conservative
    # Temperature; the dataset's salinity/temperature are passed as-is — confirm.
    sigma_all = np.asarray(gsw.sigma0(ds_station['salinity'].values,
                                      ds_station['temperature'].values))
    # NOTE(review): 44.661 looks like the mL/L -> µmol/L factor while the colour
    # bar is labelled µmol/kg (no density division here) — confirm the units.
    oxygen_all = ds_station['dissolved_oxygen_ml_l'].values * 44.661

    regular_depth = np.arange(0, 500, 1)
    oxy_profiles, dens_profiles, valid_times = [], [], []

    for t in np.unique(time_all):
        in_cast = time_all == t
        if np.count_nonzero(in_cast) < 2:
            continue
        # np.interp needs increasing x and does not check it: sort each cast by depth.
        order = np.argsort(depth_all[in_cast], kind='stable')
        cast_depth = depth_all[in_cast][order]
        oxy_profiles.append(np.interp(regular_depth, cast_depth, oxygen_all[in_cast][order],
                                      left=np.nan, right=np.nan))
        dens_profiles.append(np.interp(regular_depth, cast_depth, sigma_all[in_cast][order],
                                       left=np.nan, right=np.nan))
        valid_times.append(t)

    if not oxy_profiles:
        print(f"No usable profiles at station {station_id}")
        return

    oxy = np.array(oxy_profiles).T          # shape (depth, time)
    density = np.array(dens_profiles).T
    times = np.array(valid_times).astype('datetime64[ns]')
    oxymax = oxymax if oxymax is not None else np.nanmax(oxy)

    fig, ax = plt.subplots(figsize=(1.5 * 1.5 * 6.4, 1.5 * 4.8))
    norm = TwoSlopeNorm(vmin=0, vcenter=60, vmax=oxymax)
    pc = ax.pcolormesh(times, regular_depth, oxy, cmap=cm.cm.balance_r, norm=norm, shading='nearest')

    if show_density_contours:
        # (levels, colour, line width); only entries whose width is not 0.5 get labels.
        contour_settings = [
            (np.linspace(24, 27, 7), 'black', .51),
            ([25.6], 'white', 0.5),
            ([25.7], 'lime', 0.5),
            ([25.8], 'red', 0.5),
            ([25.9], 'blue', 0.5),
            ([26.0], 'black', 0.51),
            ([26.1], 'purple', 0.5),
            ([26.2], 'salmon', 0.5),
            ([26.3], 'yellow', 0.5),
            ([26.4], 'cyan', 0.5)
        ]
        for levels, color, lw in contour_settings:
            cf = ax.contour(times, regular_depth, density, levels=levels, colors=color, linewidths=lw)
            if lw != 0.5:
                ax.clabel(cf, fmt='%1.2f', fontsize=8)
    else:
        # Oxygen contours at intervals of 10 μmol/kg, from 0 up to the data maximum
        base_levels = np.arange(0, np.nanmax(oxy), 10)
        hypoxic_level = 60
        # Black contours (except hypoxic line)
        other_levels = [lvl for lvl in base_levels if not np.isclose(lvl, hypoxic_level)]
        if other_levels:
            cs_other = ax.contour(times, regular_depth, oxy, levels=other_levels, colors='black', linewidths=0.5)
            ax.clabel(cs_other, fmt='%1.0f', fontsize=7)
        # Red contour at 60 μmol/kg
        cs_red = ax.contour(times, regular_depth, oxy, levels=[hypoxic_level], colors='red', linewidths=2)
        ax.clabel(cs_red, fmt='%1.0f', fontsize=8)

    if show_sill:
        ax.plot(times, [130] * len(times), color='black', linestyle='--', linewidth=1.5)
        ax.plot(times, [120] * len(times), color='black', linestyle='--', linewidth=1.5)

    ax.set_title(f"Oxygen Time Series at {station_id}")
    ax.set_xlabel("Date")
    ax.set_ylabel("Depth (m)")
    ax.invert_yaxis()
    ax.set_ylim(ylim)
    fig.colorbar(pc, ax=ax, label="Oxygen (μmol/kg)")
    ax.xaxis.set_major_formatter(DateFormatter('%b %d\n%Y'))

    # Secondary top axis: one tick per cast.
    ax_top = ax.secondary_xaxis('top')
    ax_top.set_xticks(times)
    ax_top.set_xticklabels([_day_label(t) for t in times], rotation=90, fontsize=8)
    ax_top.tick_params(axis='x', direction='out', length=3, width=0.5, pad=2)

    plt.tight_layout()

plot_oxygen_timeseries("DE1", ds, ylim = (500, 0), years=[2020, 2021, 2022, 2023, 2024, 2025], oxymax = 160, show_density_contours=False)

In [None]:
# plot_oxygen_timeseries("QCS07", ds, ylim = (130, 0),years=[2020, 2021, 2022, 2023, 2024, 2025])
# plot_oxygen_timeseries("QCS01", ds, ylim = (130, 0),years=[2020, 2021, 2022, 2023, 2024, 2025])
# plot_oxygen_timeseries("FZH08", ds, ylim = (400, 0), years=[2020, 2021, 2022, 2023, 2024, 2025])
# plot_oxygen_timeseries("DFO2", ds, ylim = (400, 0), years=[2020, 2021, 2022, 2023, 2024, 2025])
# plot_oxygen_timeseries("KC10", ds, ylim = (400, 0), years=[2020, 2021, 2022, 2023, 2024, 2025])


In [None]:
plot_oxygen_timeseries("DE1", ds, ylim = (500, 0), years=[2020, 2021, 2022, 2023, 2024, 2025], oxymax = 160)
plot_station_timeseries("DE1", ds, ylim = (500, 0), years=[2020, 2021, 2022, 2023, 2024, 2025])
# plot_oxygen_timeseries("DE2", ds, ylim = (500, 0), years=[2020, 2021, 2022, 2023, 2024, 2025], oxymax = 160)
# plot_oxygen_timeseries("DE3", ds, ylim = (500, 0), years=[2020, 2021, 2022, 2023, 2024, 2025], oxymax = 160)

In [None]:
plot_oxygen_timeseries("DFO2", ds, ylim = (300, 0), years=[2020, 2021, 2022, 2023, 2024, 2025], oxymax = 160)
plot_station_timeseries("DFO2", ds, ylim = (300, 0), years=[2020, 2021, 2022, 2023, 2024, 2025])

In [None]:
plot_oxygen_timeseries("FZH08", ds, ylim = (420, 0), years=[2020, 2021, 2022, 2023, 2024, 2025], oxymax = 160)
plot_station_timeseries("FZH08", ds, ylim = (420, 0), years=[2020, 2021, 2022, 2023, 2024, 2025])

In [None]:
plot_oxygen_timeseries("FZH08", ds, ylim = (400, 0), years=[2020])

In [None]:
np.unique(ds['station'].values)

In [None]:
def plot_density_timeseries(station_id, ds, ylim=(450, 0), years=[2023, 2024, 2025], show_sill=True):
    """Plot a depth-time section of potential density (sigma-0) for one station.

    Every distinct ``time`` value at the station is treated as one cast. Each
    cast is sorted by depth and linearly interpolated onto a 1 m grid
    (0-422 m); values outside the sampled depth range are NaN. The colour
    scale is fixed to 25-26.2.

    Parameters
    ----------
    station_id : str
        Value of ``ds['station']`` to select.
    ds : xarray.Dataset
        Needs ``station``, ``time``, ``depth``, ``temperature`` and
        ``salinity`` along ``row``.
    ylim : tuple
        Depth-axis limits, passed to ``set_ylim`` (default puts 0 at the top).
    years : sequence of int
        Calendar years to keep.
    show_sill : bool
        Draw dashed reference lines at 120 and 130 on the depth axis.

    Returns
    -------
    None
        Prints a message and returns early when there is too little data.
    """
    import matplotlib.pyplot as plt
    from matplotlib.dates import DateFormatter
    import numpy as np
    import gsw
    import pandas as pd
    import cmocean as cm

    # NOTE: this changes matplotlib's global defaults for the whole session.
    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 20,
        'figure.titlesize': 20})

    def _day_label(stamp):
        # '%-d' is a platform-specific strftime extension; build the label portably.
        ts = pd.Timestamp(stamp)
        return f"{ts.strftime('%b')} {ts.day}"

    # ─── Filter and clean ─────────────────────────────
    ds_station = ds.where(ds['station'] == station_id, drop=True)
    ds_station = ds_station.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time'])
    ds_station = ds_station.where(ds_station['time'].dt.year.isin(years), drop=True)

    if ds_station.sizes['row'] < 2:
        print(f"No valid data for station {station_id} in years {years}.")
        return

    # ─── Compute sigma_theta on plain arrays ──────────
    depth_all = ds_station['depth'].values
    time_all = ds_station['time'].values
    # NOTE(review): gsw.sigma0 is documented for Absolute Salinity / Conservative
    # Temperature; the dataset's salinity/temperature are passed as-is — confirm.
    sigma_all = np.asarray(gsw.sigma0(ds_station['salinity'].values,
                                      ds_station['temperature'].values))

    # ─── Group by time, interpolate ───────────────────
    regular_depth = np.arange(0, 423, 1)
    dens_profiles, valid_times = [], []

    for t in np.unique(time_all):
        in_cast = time_all == t
        if np.count_nonzero(in_cast) < 2:
            continue
        # np.interp needs increasing x and does not check it: sort each cast by depth.
        order = np.argsort(depth_all[in_cast], kind='stable')
        dens_profiles.append(np.interp(regular_depth,
                                       depth_all[in_cast][order],
                                       sigma_all[in_cast][order],
                                       left=np.nan, right=np.nan))
        valid_times.append(t)

    if not dens_profiles:
        print(f"No usable profiles at station {station_id}")
        return

    density = np.array(dens_profiles).T     # shape (depth, time)
    times = np.array(valid_times).astype('datetime64[ns]')

    # ─── Create meshgrid for pcolormesh ───────────────
    time_mesh, depth_mesh = np.meshgrid(times, regular_depth)

    # ─── Plot ─────────────────────────────────────────
    fig, ax = plt.subplots(figsize=(1.5 * 1.5 * 6.4, 1.5 * 4.8))
    pc = ax.pcolormesh(time_mesh, depth_mesh, density, shading='nearest',
                       cmap=cm.cm.dense, vmin=25, vmax=26.2)

    if show_sill:
        ax.plot(times, [130] * len(times), color='black', linestyle='--', linewidth=1.5)
        ax.plot(times, [120] * len(times), color='black', linestyle='--', linewidth=1.5)

    ax.set_title(f"Density Time Series at {station_id}")
    ax.set_xlabel("Date")
    ax.set_ylabel("Depth (m)")
    ax.invert_yaxis()
    ax.set_ylim(ylim)
    fig.colorbar(pc, ax=ax, label="Potential Density (σ₀, kg/m³)")
    ax.xaxis.set_major_formatter(DateFormatter('%b %d\n%Y'))

    # Secondary top axis: one tick per cast.
    ax_top = ax.secondary_xaxis('top')
    ax_top.set_xticks(times)
    ax_top.set_xticklabels([_day_label(t) for t in times], rotation=90, fontsize=8)
    ax_top.tick_params(axis='x', direction='out', length=3, width=0.5, pad=2)

    plt.tight_layout()

# plot_density_timeseries("FZH08", ds, years=[2024, 2025])
plot_density_timeseries("HKP04", ds, years=[2020, 2021, 2022, 2023], ylim = (400, 50))
plot_density_timeseries("KC10", ds, years=[2020, 2021, 2022, 2023], ylim = (400, 50))

In [None]:
def plot_density_change_at_depth(station_id, ds, depth_target=300, years=[2020, 2021, 2022, 2023, 2024, 2025]):
    """Plot sigma-0 at one depth over time, and its change between consecutive casts.

    Every distinct ``time`` value at the station is treated as one cast. Casts
    whose sampled depth range does not bracket ``depth_target`` are skipped;
    the others are sorted by depth and linearly interpolated to
    ``depth_target``. The lower panel shows the difference between each cast
    and the previous one, plotted at the later cast's time.

    Parameters
    ----------
    station_id : str
        Value of ``ds['station']`` to select.
    ds : xarray.Dataset
        Needs ``station``, ``time``, ``depth``, ``temperature`` and
        ``salinity`` along ``row``.
    depth_target : float
        Depth at which density is evaluated.
    years : sequence of int
        Calendar years to keep.

    Returns
    -------
    None
        Prints a message and returns early when fewer than two casts qualify.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import gsw

    # NOTE: this changes matplotlib's global defaults for the whole session.
    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 20,
        'figure.titlesize': 20})

    # ─── Filter dataset ───────────────────────────────
    ds_station = ds.where(ds['station'] == station_id, drop=True)
    ds_station = ds_station.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time'])
    ds_station = ds_station.where(ds_station['time'].dt.year.isin(years), drop=True)

    if ds_station.sizes['row'] < 2:
        print(f"No valid data for station {station_id} in years {years}.")
        return

    # ─── Compute sigma_theta on plain arrays ──────────
    depth_all = ds_station['depth'].values
    time_all = ds_station['time'].values
    # NOTE(review): gsw.sigma0 is documented for Absolute Salinity / Conservative
    # Temperature; the dataset's salinity/temperature are passed as-is — confirm.
    sigma_all = np.asarray(gsw.sigma0(ds_station['salinity'].values,
                                      ds_station['temperature'].values))

    # ─── Interpolate to fixed depth ───────────────────
    density_at_depth = []
    valid_times = []

    for t in np.unique(time_all):
        in_cast = time_all == t
        if np.count_nonzero(in_cast) < 2:
            continue
        # np.interp needs increasing x and does not check it: sort each cast by depth.
        order = np.argsort(depth_all[in_cast], kind='stable')
        depth_vals = depth_all[in_cast][order]
        sigma_vals = sigma_all[in_cast][order]
        # Only casts that actually bracket the target depth are used.
        if depth_vals[0] <= depth_target <= depth_vals[-1]:
            density_at_depth.append(
                np.interp(depth_target, depth_vals, sigma_vals, left=np.nan, right=np.nan))
            valid_times.append(t)

    if len(density_at_depth) < 2:
        print("Not enough valid profiles for comparison.")
        return

    density_at_depth = np.array(density_at_depth)
    time_array = np.array(valid_times).astype('datetime64[ns]')
    delta_density = np.diff(density_at_depth)
    delta_time = time_array[1:]             # each change is stamped with the later cast

    # ─── Plot ─────────────────────────────────────────
    fig, ax = plt.subplots(2, 1, figsize=(10, 6), sharex=True)

    ax[0].plot(time_array, density_at_depth, marker='o', label=f'Density at {depth_target} m')
    ax[0].set_ylabel("σ₀ (kg/m³)")
    ax[0].set_title(f"Potential Density at {depth_target} m — {station_id}")
    ax[0].grid(True)

    ax[1].plot(delta_time, delta_density, marker='o', color='tab:red', label='Δσ₀')
    ax[1].set_ylabel("Δσ₀ between profiles")
    ax[1].set_xlabel("Time")
    ax[1].set_title("Change in Density Between Profiles")
    ax[1].grid(True)

    plt.tight_layout()

plot_density_change_at_depth("FZH08", ds)

In [None]:
def plot_density_derivative_at_depth(station_id, ds, depth_target=300, years=[2024, 2025]):
    """Plot sigma-0 at one depth over time, and its rate of change per day.

    Every distinct ``time`` value at the station is treated as one cast. Casts
    whose sampled depth range does not bracket ``depth_target`` are skipped;
    the others are sorted by depth and linearly interpolated to
    ``depth_target``. The lower panel shows the finite difference between
    consecutive casts divided by the elapsed days, plotted at the later cast.

    Parameters
    ----------
    station_id : str
        Value of ``ds['station']`` to select.
    ds : xarray.Dataset
        Needs ``station``, ``time``, ``depth``, ``temperature`` and
        ``salinity`` along ``row``.
    depth_target : float
        Depth at which density is evaluated.
    years : sequence of int
        Calendar years to keep.

    Returns
    -------
    None
        Prints a message and returns early when fewer than two casts qualify.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import gsw

    # ─── Filter dataset ───────────────────────────────
    ds_station = ds.where(ds['station'] == station_id, drop=True)
    ds_station = ds_station.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time'])
    ds_station = ds_station.where(ds_station['time'].dt.year.isin(years), drop=True)

    if ds_station.sizes['row'] < 2:
        print(f"No valid data for station {station_id} in years {years}.")
        return

    # ─── Compute sigma_theta on plain arrays ──────────
    depth_all = ds_station['depth'].values
    time_all = ds_station['time'].values
    # NOTE(review): gsw.sigma0 is documented for Absolute Salinity / Conservative
    # Temperature; the dataset's salinity/temperature are passed as-is — confirm.
    sigma_all = np.asarray(gsw.sigma0(ds_station['salinity'].values,
                                      ds_station['temperature'].values))

    # ─── Interpolate to fixed depth ───────────────────
    density_at_depth = []
    valid_times = []

    for t in np.unique(time_all):
        in_cast = time_all == t
        if np.count_nonzero(in_cast) < 2:
            continue
        # np.interp needs increasing x and does not check it: sort each cast by depth.
        order = np.argsort(depth_all[in_cast], kind='stable')
        depth_vals = depth_all[in_cast][order]
        sigma_vals = sigma_all[in_cast][order]
        # Only casts that actually bracket the target depth are used.
        if depth_vals[0] <= depth_target <= depth_vals[-1]:
            density_at_depth.append(
                np.interp(depth_target, depth_vals, sigma_vals, left=np.nan, right=np.nan))
            valid_times.append(t)

    if len(density_at_depth) < 2:
        print("Not enough valid profiles for derivative.")
        return

    # ─── Compute dσ/dt (finite difference) ────────────
    density_at_depth = np.array(density_at_depth)
    time_array = np.array(valid_times).astype('datetime64[ns]')
    delta_days = np.diff(time_array) / np.timedelta64(1, 'D')   # fractional days
    d_sigma_dt = np.diff(density_at_depth) / delta_days

    # ─── Plot ─────────────────────────────────────────
    fig, ax = plt.subplots(2, 1, figsize=(10, 6), sharex=True)

    ax[0].plot(time_array, density_at_depth, marker='o', label=f'σ₀ at {depth_target} m')
    ax[0].set_ylabel("σ₀ (kg/m³)")
    ax[0].set_title(f"Potential Density at {depth_target} m — {station_id}")
    ax[0].grid(True)

    # Each rate is stamped with the later cast of the pair.
    ax[1].plot(time_array[1:], d_sigma_dt, marker='o', color='tab:red', label='dσ₀/dt')
    ax[1].set_ylabel("dσ₀/dt (kg/m³/day)")
    ax[1].set_xlabel("Time")
    ax[1].set_title("Rate of Density Change (dσ₀/dt)")
    ax[1].grid(True)

    plt.tight_layout()

plot_density_derivative_at_depth("FZH08", ds)

In [None]:

import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import warnings
warnings.filterwarnings("ignore") # For those annoying runtime warnings
import cartopy.feature as cfeature
import numpy as np
import os
import cmocean as cm

# Load dataset
# NOTE: this repeats the load done in the notebook's first cell and rebinds the
# notebook-level `ds` to a freshly opened handle on the same file.
file_path = os.path.expanduser('~/Desktop/Summer 2025 Python/Hakai_calvert.nc')
ds = xr.open_dataset(file_path)

def plot_density_derivative_profile(station_id, ds, years=[2024, 2025], depth_bin=1, ylim=(350, 0)):
    """Plot the rate of change of sigma-0 between consecutive casts at every depth.

    Every distinct ``time`` value at the station is treated as one cast. Only
    casts whose shallowest sample is above 10 m are kept; each is sorted by
    depth and linearly interpolated onto a 0-500 m grid with spacing
    ``depth_bin`` (NaN outside the sampled range). The difference between
    consecutive casts is divided by the elapsed time in months (30.44 days)
    and drawn as a depth-time section, each column stamped with the later cast.

    Parameters
    ----------
    station_id : str
        Value of ``ds['station']`` to select.
    ds : xarray.Dataset
        Needs ``station``, ``time``, ``depth``, ``temperature`` and
        ``salinity`` along ``row``.
    years : sequence of int
        Calendar years to keep.
    depth_bin : float
        Spacing of the common depth grid.
    ylim : tuple
        Depth-axis limits, passed to ``set_ylim`` (default puts 0 at the top).

    Returns
    -------
    None
        Prints a message and returns early when fewer than two casts qualify.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import pandas as pd
    import gsw
    from matplotlib.dates import DateFormatter

    # NOTE: this changes matplotlib's global defaults for the whole session.
    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 20,
        'figure.titlesize': 20})

    def _day_label(stamp):
        # '%-d' is a platform-specific strftime extension; build the label portably.
        ts = pd.Timestamp(stamp)
        return f"{ts.strftime('%b')} {ts.day}"

    # ─── Filter dataset ─────────────────────────────
    ds_station = ds.where(ds['station'] == station_id, drop=True)
    ds_station = ds_station.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time'])
    ds_station = ds_station.where(ds_station['time'].dt.year.isin(years), drop=True)

    if ds_station.sizes['row'] < 2:
        print(f"No valid data for station {station_id} in years {years}.")
        return

    # ─── Compute potential density σ₀ on plain arrays ─
    depth_all = ds_station['depth'].values
    time_all = ds_station['time'].values
    # NOTE(review): gsw.sigma0 is documented for Absolute Salinity / Conservative
    # Temperature; the dataset's salinity/temperature are passed as-is — confirm.
    sigma_all = np.asarray(gsw.sigma0(ds_station['salinity'].values,
                                      ds_station['temperature'].values))

    # ─── Interpolate onto common depth grid ─────────
    depth_grid = np.arange(0, 500, depth_bin)
    sigma_interp = []
    time_list = []

    for t in np.unique(time_all):
        in_cast = time_all == t
        if np.count_nonzero(in_cast) < 2:
            continue
        # np.interp needs increasing x and does not check it: sort each cast by depth.
        order = np.argsort(depth_all[in_cast], kind='stable')
        z = depth_all[in_cast][order]
        sigma = sigma_all[in_cast][order]

        # Keep only casts that start shallower than 10 m.
        if z[0] < 10:
            sigma_interp.append(np.interp(depth_grid, z, sigma, left=np.nan, right=np.nan))
            time_list.append(t)

    if len(sigma_interp) < 2:
        print("Not enough valid profiles for derivative.")
        return

    # ─── Compute dσ₀/dt ─────────────────────────────
    sigma_interp = np.array(sigma_interp)   # shape (time, depth)
    time_array = np.array(time_list).astype('datetime64[ns]')
    # Fractional days: truncating to whole days made two casts on the same
    # calendar day a zero interval, i.e. a division by zero.
    delta_days = np.diff(time_array) / np.timedelta64(1, 'D')
    delta_months = delta_days / 30.44  # avg days/month
    d_sigma_dt = np.diff(sigma_interp, axis=0) / delta_months[:, np.newaxis]
    time_end = time_array[1:]               # each column is stamped with the later cast

    # ─── Plotting ───────────────────────────────────
    fig, ax = plt.subplots(figsize=(1.5 * 1.5 * 6.4, 1.5 * 4.8))

    pc = ax.pcolormesh(time_end, depth_grid, d_sigma_dt.T,
                       cmap='RdBu_r', shading='auto',
                       vmin=-0.4, vmax=0.4)

    ax.invert_yaxis()
    ax.set_title(f"dσ₀/dt at all depths — {station_id}")
    ax.set_xlabel("Date")
    ax.set_ylabel("Depth (m)")
    ax.set_ylim(ylim)
    fig.colorbar(pc, ax=ax, label="dσ₀/dt (kg/m³/month)")

    # ─── Time formatting ────────────────────────────
    ax.xaxis.set_major_formatter(DateFormatter('%b %d\n%Y'))

    ax_top = ax.secondary_xaxis('top')
    ax_top.set_xticks(time_end)
    ax_top.set_xticklabels([_day_label(t) for t in time_end], rotation=90, fontsize=8)
    ax_top.tick_params(axis='x', direction='out', length=3, width=0.5, pad=2)
    plt.tight_layout()

# plot_station_timeseries("KC10", ds, years=(2021, 2022, 2023), ylim= (350, 0))
# plot_density_derivative_profile("KC10", ds, years=[2021, 2022, 2023])

# `(2020)` is just the int 2020, not a one-element tuple; pass a list like the
# other plot_station_timeseries calls in this notebook do.
plot_station_timeseries("KC10", ds, years=[2020], ylim= (350, 0))
plot_station_timeseries("FZH08", ds, years=[2020], ylim= (350, 0))
plot_station_timeseries("HKP04", ds, years=[2020], ylim= (350, 0))

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

def load_and_plot_bakun(filepath, years=[2023], window=7):
    """Read a Bakun index file and plot a rolling-mean time series.

    The file is parsed as whitespace-separated ``date index`` rows (first line
    skipped, dates in ``YYYYMMDD`` form). Rows whose index equals ``-9999``
    are dropped. The selected years are reindexed to a gap-free daily series,
    linearly interpolated, and smoothed with a trailing rolling mean. Positive
    values are shaded red ("Upwelling"), non-positive values blue
    ("Downwelling").

    Parameters
    ----------
    filepath : str
        Path of the index file.
    years : sequence of int
        Calendar years to plot.
    window : int
        Rolling-mean window in days. The unsmoothed series is also drawn when
        ``window`` is below 7.

    Returns
    -------
    None
        Prints a message and returns early when no rows fall in ``years``.
    """
    # sep=r'\s+' is the documented replacement for the deprecated delim_whitespace=True.
    df = pd.read_csv(
        filepath,
        sep=r'\s+',
        skiprows=1,
        names=['date', 'index']
    )
    # Drop missing data; .copy() so the column assignments below hit a real frame.
    df = df[df['index'] != -9999].copy()
    df['date'] = pd.to_datetime(df['date'].astype(str), format='%Y%m%d')
    df['year'] = df['date'].dt.year

    # Filter by target years
    df_filtered = df[df['year'].isin(years)].copy()
    if df_filtered.empty:
        print(f"No data for years: {years}")
        return

    # Resample to daily and interpolate
    df_filtered = df_filtered.set_index('date').sort_index()
    # reindex() refuses duplicate labels; keep the first row of any repeated date.
    df_filtered = df_filtered[~df_filtered.index.duplicated(keep='first')]
    full_range = pd.date_range(df_filtered.index.min(), df_filtered.index.max(), freq="D")
    df_full = df_filtered.reindex(full_range)
    df_full.index.name = 'date'
    df_full['index'] = df_full['index'].interpolate(limit_direction='both')

    # Rolling smoothing (trailing window)
    smoothed = df_full['index'].rolling(window=window, center=False, min_periods=1).mean()

    fig, ax = plt.subplots(figsize=(1.5 * 1.5 * 6.4, 1.5 * 4.8))
    if window < 7:
        ax.plot(df_full.index, df_full['index'], color='lightgray', label='Raw')
    ax.plot(df_full.index, smoothed, color='black', label=f'{window}-day smoothed')

    # Fill under the curve
    ax.fill_between(df_full.index, 0, smoothed, where=smoothed > 0, color='red', alpha=0.3, label='Upwelling')
    ax.fill_between(df_full.index, 0, smoothed, where=smoothed <= 0, color='blue', alpha=0.3, label='Downwelling')

    ax.axhline(0, color='black', linestyle='--')
    ax.set_title(f"{window}-day smoothed Bakun Index at 51°N 131°W")
    ax.set_xlabel("Date")
    ax.set_ylabel(r"Index (m$^3$/s per 100m coastline)")
    ax.grid(True)
    ax.legend()

    # x-limits: Jan 1 of the earliest requested year to Jun 1 of the latest one.
    # min/max keeps the axis the right way round even if `years` is unsorted.
    tmin = pd.Timestamp(f"{min(years)}-01-01")
    tmax = pd.Timestamp(f"{max(years)}-06-01")
    ax.set_xlim(tmin, tmax)

    plt.show()

# load_and_plot_bakun("~/Desktop/p05dayac.all", years=[2020], window = 7)
# load_and_plot_bakun("~/Desktop/p05dayac.all", years=[2020])
# load_and_plot_bakun("~/Desktop/p05dayac.all", years=[2020], window = 28)
# load_and_plot_bakun("~/Desktop/p05dayac.all", years=[2020, 2021, 2022, 2023, 2024, 2025], window = 28)
load_and_plot_bakun("~/Desktop/p05dayac.all", years=[2024, 2025], window = 7)
load_and_plot_bakun("~/Desktop/p05dayac.all", years=[2024, 2025], window = 30)

In [None]:
import pandas as pd

# Ad-hoc check of the 2020 Bakun index: load, smooth, and list the spacing
# between consecutive dates to spot gaps.
# sep=r'\s+' is the documented replacement for the deprecated delim_whitespace=True.
df = pd.read_csv(
    "~/Desktop/p05dayac.all",
    sep=r'\s+',
    skiprows=1,
    names=['date', 'index']
)
# Drop missing data; .copy() so the column assignments below hit a real frame.
df = df[df['index'] != -9999].copy()
df['date'] = pd.to_datetime(df['date'].astype(str), format='%Y%m%d')
df['year'] = df['date'].dt.year

df_filtered = df[df['year'].isin([2020])].set_index('date')

# SORT the index!
df_filtered = df_filtered.sort_index()

# Rolling mean (centred 7-row window; not used further in this cell)
smoothed = df_filtered['index'].rolling(window=7, center=True, min_periods=1).mean()

# Show date gaps: difference between each date and the previous one
date_diffs = df_filtered.index.to_series().diff()
print(date_diffs)

In [None]:
def plot_monthly_derivatives_profile(station_id, ds, years=[2024, 2025], depth_bin=1, ylim=(350, 0)):
    """Plot time-depth panels of the rate of change of T, S and sigma0 at one station.

    Every cast (unique ``time`` value) at ``station_id`` is interpolated onto a
    regular depth grid; consecutive casts are differenced and divided by the
    elapsed time expressed in months (days / 30.44). The three resulting fields
    are drawn as stacked ``pcolormesh`` panels that share the depth axis.

    Parameters
    ----------
    station_id : value compared against ``ds['station']`` to select rows.
    ds : dataset providing ``station``, ``temperature``, ``salinity``,
        ``depth`` and ``time`` along a ``row`` dimension.
    years : iterable of int
        Calendar years to keep. The default list is never mutated here.
    depth_bin : number
        Spacing of the interpolation grid, which runs from 0 up to (not
        including) 500 in the units of ``ds['depth']``.
    ylim : tuple
        Passed to ``Axes.set_ylim`` on every panel; the default ``(350, 0)``
        puts the surface at the top.

    Returns
    -------
    None. A message is printed and the function returns early when fewer than
    two rows or fewer than two usable casts are available.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import pandas as pd
    import gsw
    # Imported locally (like the other dependencies) so the colormap lookup
    # below does not depend on a notebook-level alias named `cm`.
    import cmocean
    from matplotlib.dates import DateFormatter

    # NOTE: this changes the global rcParams, so these font sizes also apply
    # to figures created after this function has run.
    plt.rcParams.update({
        'font.size': 12,
        'axes.titlesize': 20,
        'axes.labelsize': 20,
        'xtick.labelsize': 15,
        'ytick.labelsize': 15,
        'legend.fontsize': 20,
        'figure.titlesize': 20})

    # ─── Filter and clean ─────────────────────────────
    ds_station = ds.where(ds['station'] == station_id, drop=True)
    ds_station = ds_station.dropna(dim='row', subset=['temperature', 'salinity', 'depth', 'time'])
    ds_station = ds_station.where(ds_station['time'].dt.year.isin(years), drop=True)

    if ds_station.sizes['row'] < 2:
        print(f"No valid data for station {station_id} in years {years}.")
        return

    # ─── Interpolate every cast onto a common depth grid ──────────────
    times = np.unique(ds_station['time'].values)  # sorted, one entry per cast
    depth_grid = np.arange(0, 500, depth_bin)
    temp_interp, sal_interp, sigma_interp = [], [], []
    time_list = []

    for t in times:
        profile = ds_station.where(ds_station['time'] == t, drop=True)
        if profile.sizes['row'] < 2:
            continue

        z = profile['depth'].values
        T = profile['temperature'].values
        S = profile['salinity'].values

        # Only keep casts whose shallowest sample is above 10 (depth units).
        if not (np.nanmin(z) < 10):
            continue

        # np.interp requires increasing sample points and does not verify it;
        # rows are not guaranteed to arrive in depth order, so sort first.
        order = np.argsort(z, kind='stable')
        z, T, S = z[order], T[order], S[order]

        # NOTE(review): gsw.sigma0 is defined for Absolute Salinity and
        # Conservative Temperature; if these variables are practical salinity
        # and in-situ temperature they should be converted first — confirm.
        sigma = gsw.sigma0(S, T)

        # Grid points outside the sampled depth range become NaN.
        temp_interp.append(np.interp(depth_grid, z, T, left=np.nan, right=np.nan))
        sal_interp.append(np.interp(depth_grid, z, S, left=np.nan, right=np.nan))
        sigma_interp.append(np.interp(depth_grid, z, sigma, left=np.nan, right=np.nan))
        time_list.append(t)

    if len(temp_interp) < 2:
        print("Not enough valid profiles for derivative.")
        return

    # ─── Convert to arrays ─────────────────────────────
    time_array = np.array(time_list).astype('datetime64[ns]')
    # Despite the name, each difference is plotted at the time of the LATER
    # cast of the pair, not at the midpoint between the two casts.
    time_mid = time_array[1:]
    # Fractional days: truncating to whole days would make two casts on the
    # same calendar day look 0 days apart and yield infinite rates.
    delta_days = np.diff(time_array) / np.timedelta64(1, 'D')
    delta_months = delta_days / 30.44

    temp_interp = np.array(temp_interp)
    sal_interp = np.array(sal_interp)
    sigma_interp = np.array(sigma_interp)

    # Cast-to-cast change per month; rows are cast pairs, columns are depths.
    dT_dt = np.diff(temp_interp, axis=0) / delta_months[:, np.newaxis]
    dS_dt = np.diff(sal_interp, axis=0) / delta_months[:, np.newaxis]
    dSigma_dt = np.diff(sigma_interp, axis=0) / delta_months[:, np.newaxis]

    # ─── Plotting ──────────────────────────────────────
    fig, axs = plt.subplots(3, 1, figsize=(1.5 * 1.5 * 6.4, 3* 1.5 * 4.8), sharey=True)

    titles = ['dT/dt (°C/month)', 'dS/dt (psu/month)', 'dσ₀/dt (kg/m³/month)']
    data = [dT_dt, dS_dt, dSigma_dt]
    cmaps = [cmocean.cm.balance, cmocean.cm.balance, cmocean.cm.balance]
    vlims = [(-0.6, 0.6), (-0.5, 0.5), (-0.5, 0.5)]  # fixed colour limits per panel

    # '%-d' (day without zero padding) is not supported by every platform's
    # strftime, so the day number is formatted from the attribute instead.
    top_labels = [f"{ts.strftime('%b')} {ts.day}" for ts in map(pd.Timestamp, time_mid)]

    for ax, field, title, cmap, (vmin, vmax) in zip(axs, data, titles, cmaps, vlims):
        pc = ax.pcolormesh(time_mid, depth_grid, field.T,
                           cmap=cmap, shading='auto', vmin=vmin, vmax=vmax)
        ax.invert_yaxis()
        ax.set_ylim(ylim)
        ax.set_title(title)
        ax.set_xlabel("Date")
        ax.xaxis.set_major_formatter(DateFormatter('%b %d\n%Y'))

        # Top axis: one rotated tick per plotted cast date
        ax_top = ax.secondary_xaxis('top')
        ax_top.set_xticks(time_mid)
        ax_top.set_xticklabels(top_labels, rotation=90)
        ax_top.tick_params(axis='x', direction='out', length=3, width=0.5, pad=2)

        fig.colorbar(pc, ax=ax, label=title)

    axs[0].set_ylabel("Depth (m)")
    fig.suptitle(f"Monthly Rate of Change at {station_id}")
    plt.tight_layout()

# Draw the rate-of-change panels for station FZH08, restricted to casts from 2020.
plot_monthly_derivatives_profile("FZH08", ds, years=[2020])