# Exercise 1: Profile Data

Aim: To work with vertical profile data and make some standard calculations.

- Author: Yves
- Purpose: Plot profile data
- Date: 2024-04-02

This **worked**!  But *not sure* what to do next.

<hr>

In [1]:
import warnings
warnings.simplefilter('ignore')  # silence every warning for the rest of this notebook

In [2]:
import matplotlib.pyplot as plt
import pandas as pd
import gsw 
import numpy as np
import xarray as xr
import os
import pycnv

from matplotlib.ticker import ScalarFormatter
from seabird.cnv import fCNV
from seabird.netcdf import cnv2nc
from datetime import datetime
#from modules.nsq import adiabatic_leveling
import mixsea as mx

ModuleNotFoundError: No module named 'pycnv'

In [None]:
# Input file path of the CTD cast (.cnv)
# NOTE: an earlier assignment to '../shared_data/MSM121_054_1db.cnv' was
# immediately overwritten by the line below and has been dropped; point
# input_path_cnv at that location instead if the cast lives in the shared folder.
input_path_cnv = '../data/MSM121_054_1db.cnv'

# Output file paths for netCDF (plain conversion / with derived variables added)
output_path_netcdf = '../data/MSM121_054_1db.nc'
output_path_netcdf_extended = '../data/MSM121_054_1db_extended.nc'

# Output file paths for figures
output_path_profile_1 = './figures/Ex1-Sorge-profile-channel1-efw.png'
output_path_profile_2 = './figures/Ex1-Sorge-profile-channel2-efw.png'
output_path_ts = './figures/Ex1-Sorge-ts-diagram-efw.png'
output_path_ts_with_depth = './figures/Ex1-Sorge-ts-diagram-with-depth-efw.png'
output_path_buoyancy_frequency_linear_scale = './figures/Ex1-Sorge-buoyancy-frequency-linear-efw.png'
output_path_buoyancy_frequency_log_scale = './figures/Ex1-Sorge-buoyancy-frequency-log-efw.png'
output_path_dynamic_steric_height = './figures/Ex1-Sorge-dynamic-steric-height-efw.png'


In [None]:
# Number of the station; it is interpolated into the figure titles further down
station_number: int = 54

## Working with CNV original file

### Loading the file

In [None]:
# Read data via pycnv; the resulting object is queried below through its
# attributes (lon, lat, date, p, SA, T, ...) and its `data` / `units` mappings
cnv1 = pycnv.pycnv(input_path_cnv)

In [None]:
# Summarise the cast on screen: region flag, position, time and sensor entries
sensor_entries = cnv1.data.keys()
print('Test if we are in the Baltic Sea (usage of different equation of state): ' + str(cnv1.baltic))
print('Position of cast is: Longitude:', cnv1.lon, 'Latitude:', cnv1.lat)
print('Time of cast was:', cnv1.date)
print('Number of sensor entries (len(cnv.data.keys())):', len(sensor_entries))
print('Names of sensor entries (cnv.data.keys()):', sensor_entries)

### Printing the data

In [None]:
# Keep the first sensor entry at hand as a sample ...
sensor_names = list(cnv1.data.keys())
key0 = sensor_names[0]
data0 = cnv1.data[key0]

# ... then dump the raw data, its keys and the units
for item in (cnv1.data, cnv1.data.keys(), cnv1.units):
    print(item)


### Plotting the data

In [None]:
def plot_vertical_profile(pressure: list, pressure_label: str = 'Pressure [db]', parameters: list = None,
                parameter_labels: list = None, output_filename: str = None, title: str = None):
    """ Plots every given parameter against pressure, one subplot per parameter.

    Args:
        pressure: Pressure values shared by all subplots (y-axis).
        pressure_label: Label of the y-axis; only drawn on the leftmost subplot.
        parameters: Sequence of value series, one per subplot (x-axis).
        parameter_labels: x-axis labels, one per entry of `parameters`.
        output_filename: If given, the figure is saved to this path.
        title: If given, used as title for the entire figure.

    Raises:
        ValueError: If `parameters` is empty or its length differs from the
            length of `parameter_labels`.
    """
    # None replaces the former mutable [] defaults
    parameters = [] if parameters is None else parameters
    parameter_labels = [] if parameter_labels is None else parameter_labels

    # Consistency checks
    if len(parameters) != len(parameter_labels):
        raise ValueError("Length of parameters list must be the same of parameter_labels list.")
    if len(parameters) == 0:
        raise ValueError("At least one parameter is required to plot a profile.")

    # squeeze=False always yields a 2-D array of axes, so a single
    # parameter can be indexed exactly like several
    fig, axs = plt.subplots(1, len(parameters), figsize=(10,7), sharex=False, squeeze=False)

    # One subplot per parameter
    for index, (ax, values, values_label) in enumerate(zip(axs[0], parameters, parameter_labels)):
        ax.plot(values, pressure, color='blue', linewidth=2)  # plot data values
        ax.set_xlabel(values_label, fontsize=12)  # label for x-axis
        if index < 1:
            ax.set_ylabel(pressure_label, fontsize=12)  # y-axis label on the leftmost subplot only
        ax.invert_yaxis()  # invert y-axis since we look down in the ocean
        ax.grid(True, color='gray', linestyle='--', linewidth=0.5)  # show grid
        ax.tick_params(axis='both', labelsize=12)  # format font size of ticks

    # Adjust space between subplots
    plt.subplots_adjust(wspace=(0.1+(len(parameters)-1)*0.2))

    # Set title for the entire figure if given
    if title:
        fig.suptitle(title, fontsize=14)

    # Save plot if file name is given
    if output_filename:
        plt.savefig(output_filename)

    # Show plot
    plt.show()

In [None]:
# One (values, axis label) pair per subplot of the pycnv-based profile
cnv_profile_panels = [
    (cnv1.SA, f"Absolute salinity [{cnv1.SA_unit}]"),  # salinity
    (cnv1.T, f"Temperature [{cnv1.T_unit}]"),  # temperature
    (cnv1.data['sbeox0ML/L'], f"Oxygen [{cnv1.units['sbeox0ML/L']}]"),  # oxygen
]

plot_vertical_profile(
    pressure = cnv1.p,
    pressure_label = f"Pressure [{cnv1.p_unit}]",
    parameters = [values for values, _ in cnv_profile_panels],
    parameter_labels = [label for _, label in cnv_profile_panels],
    title = f"Station {station_number} ({cnv1.lat:.2f}°N, {cnv1.lon:.2f}°E)"
)

## Transforming to xarray

### Convert CNV file to netCDF

In [None]:
# Instead of using seabird package on CLI for conversion:
# do it here programmatically, but only when the .nc file does not exist yet
if not os.path.exists(output_path_netcdf):
    data = fCNV(input_path_cnv) # read .cnv file
    # pass the freshly parsed cast; `ds` does not exist yet at this point of the notebook
    cnv2nc(data, output_path_netcdf) # write .nc file

### Open file and show contents

In [None]:
# Open the converted cast as an xarray dataset
ds = xr.open_dataset(output_path_netcdf)
ds['PRES'].attrs['long_name'] = ds['PRES'].attrs['long_name'].replace(', Digiquartz', '') # remove digiquartz hint
# info() writes the summary to stdout itself and returns None, so wrapping
# it in print() only appended a stray "None" line
ds.info()

## Calculating TEOS-10 parameters

Calculating absolute salinity and conservative temperature for both primary and secondary channels (sensors).


In [None]:
def get_SA(conductivity: list, temperature: list, pressure: list, longitude: float, latitude: float) -> list:
    """ Derives absolute salinity from conductivity by way of practical salinity. """
    practical_salinity = gsw.SP_from_C(conductivity, temperature, pressure)
    return gsw.SA_from_SP(practical_salinity, pressure, longitude, latitude)

def get_CT(SA: list, temperature: list, pressure: list) -> list:
    """ Converts temperature to conservative temperature via gsw.CT_from_t. """
    return gsw.CT_from_t(SA, temperature, pressure)

def get_PT(SA: list, temperature: list, pressure: list) -> list:
    """ Converts temperature to potential temperature via gsw.pt0_from_t. """
    return gsw.pt0_from_t(SA, temperature, pressure)



In [None]:
# Calculate depth from pressure at the latitude of the cast.
# NOTE(review): gsw.z_from_p presumably returns height (negative below the
# sea surface) - the plotting cells further down wrap DEPTH in np.abs() - confirm
ds['DEPTH'] = gsw.z_from_p(ds['PRES'], lat=ds.attrs['LATITUDE']) # convert pressure to depth

### Channel/Sensors 1

In [None]:
# Inputs of the primary channel, shared by all three conversions
temp1 = ds['TEMP'].values
pres = ds['PRES'].values

# Calculate Absolute Salinity
ds['SA1'] = (('scan',), get_SA(ds['c0mSPercm'].values, temp1,
    pres, ds.attrs['LONGITUDE'], ds.attrs['LATITUDE']))
# get_SA returns absolute salinity, so label it as such (was 'Practical Salinity')
ds['SA1'].attrs['long_name'] = 'Absolute Salinity [g/kg]'

# Calculate Conservative Temperature
ds['CT1'] = (('scan',), get_CT(ds['SA1'].values, temp1, pres))
ds['CT1'].attrs['long_name'] = 'Cons. Temperature [ITS-90, deg C]'

# Calculate Potential Temperature
ds['PT1'] = (('scan',), get_PT(ds['SA1'].values, temp1, pres))
ds['PT1'].attrs['long_name'] = 'Potential Temperature [ITS-90, deg C]'


In [None]:
# Variables of the primary channel in plotting order: salinity, temperature, oxygen.
# Values and axis labels are both derived from this one list so they cannot get
# out of step (the salinity and temperature labels used to be swapped).
channel1_variables = ['SA1', 'CT1', 'oxygen_ml_L']

plot_vertical_profile(
    pressure = ds['PRES'],
    pressure_label = ds['PRES'].attrs['long_name'],
    parameters = [ds[name] for name in channel1_variables],
    parameter_labels = [ds[name].attrs['long_name'] for name in channel1_variables],
    title = f"Station {station_number} ({ds.attrs['LATITUDE']:.2f}°N, {ds.attrs['LONGITUDE']:.2f}°E)",
    output_filename = output_path_profile_1
)

### Channel/Sensors 2

In [None]:
# Inputs of the secondary channel, shared by all three conversions
temp2 = ds['TEMP2'].values
pres = ds['PRES'].values

# Calculate Absolute Salinity
ds['SA2'] = (('scan',), get_SA(ds['c1mSPercm'].values, temp2,
    pres, ds.attrs['LONGITUDE'], ds.attrs['LATITUDE']))
# get_SA returns absolute salinity, so label it as such (was 'Practical Salinity 2')
ds['SA2'].attrs['long_name'] = 'Absolute Salinity 2 [g/kg]'

# Calculate Conservative Temperature
ds['CT2'] = (('scan',), get_CT(ds['SA2'].values, temp2, pres))
ds['CT2'].attrs['long_name'] = 'Cons. Temperature 2 [ITS-90, deg C]'

# Calculate Potential Temperature
ds['PT2'] = (('scan',), get_PT(ds['SA2'].values, temp2, pres))
ds['PT2'].attrs['long_name'] = 'Potential Temperature 2 [ITS-90, deg C]'

In [None]:
# Variables of the secondary channel in plotting order: salinity, temperature, oxygen.
# Values and axis labels are both derived from this one list so they cannot get
# out of step (the salinity and temperature labels used to be swapped).
channel2_variables = ['SA2', 'CT2', 'sbeox1MLPerL']

plot_vertical_profile(
    pressure = ds['PRES'],
    pressure_label = ds['PRES'].attrs['long_name'],
    parameters = [ds[name] for name in channel2_variables],
    parameter_labels = [ds[name].attrs['long_name'] for name in channel2_variables],
    title = f"Station {station_number} ({ds.attrs['LATITUDE']:.2f}°N, {ds.attrs['LONGITUDE']:.2f}°E)",
    output_filename = output_path_profile_2
)

### Saving xarray to netCDF file

In [None]:
# Write the dataset, now including the variables added in the cells above
# (DEPTH, SA*, CT*, PT*), to the extended netCDF file
ds.to_netcdf(output_path_netcdf_extended)

## T-S Diagram

A T-S diagram has salinity on the x-axis and temperature on the y-axis, and contours of density (sigma_0) added.

In [None]:

def plot_ts_diagram(salinity: list, temperature: list, depth: list = None, output_file: str = None, title: str = 'T-S Diagram', 
            dot_size: int = 70, show_density_isolines: bool = True, show_lines_between_dots: bool = True,
            label_yaxis: str = 'Temperature [°C]', label_xaxis: str = 'Salinity [g/kg]',
            show_grid: bool = True, figsize=(10, 7), linewidth: float = 0.5, linecolor: str = 'blue', 
            depth_colormap: str = None):
    """ Plots a T-S diagram (salinity on the x-axis, temperature on the y-axis).

    Args:
        salinity: Salinity values (x-axis).
        temperature: Temperature values (y-axis).
        depth: Values used to colour the dots; only read when
            `depth_colormap` is given.
        output_file: If given, the figure is saved to this path.
        title: Title of the plot; pass a falsy value to omit it.
        dot_size: Marker size of the scatter plot.
        show_density_isolines: Overlay dashed density contours.
        show_lines_between_dots: Connect consecutive samples with a line.
        label_yaxis: Label of the y-axis.
        label_xaxis: Label of the x-axis.
        show_grid: Draw grid lines.
        figsize: Figure size handed to matplotlib.
        linewidth: Width of the connecting line.
        linecolor: Colour of the connecting line.
        depth_colormap: Name of a colormap; when given, dots are coloured
            by `depth` and a colorbar is added, otherwise dots are blue.
    """

    def plot_density_isolines(temperature: list, salinity: list):
        """ Overlays density contours on the current T-S axes. """
        max_points = 150  # grid points along the better resolved axis
        min_points = 10   # lower bound, so a very narrow range never rounds down to an empty grid

        def padded_range(values, padding=0.1):
            """ Returns the NaN-ignoring (min, max) of values, widened by `padding` on both sides. """
            low, high = np.nanmin(values), np.nanmax(values)
            width = high - low
            if width == 0:
                width = 1.0  # constant profile: pad by a fixed amount instead of by zero
            return low - (width * padding), high + (width * padding)

        t_min, t_max = padded_range(temperature)
        s_min, s_max = padded_range(salinity)
        t_span, s_span = t_max - t_min, s_max - s_min

        # Grid dimensions depend on the ratio of the two ranges
        if t_span > s_span:
            xdim, ydim = max_points, int(np.round(max_points * s_span / t_span))
        else:
            ydim, xdim = max_points, int(np.round(max_points * t_span / s_span))
        xdim, ydim = max(xdim, min_points), max(ydim, min_points)

        # Create temp and salt vectors of appropriate dimensions
        ti = np.linspace(t_min, t_max, ydim)
        si = np.linspace(s_min, s_max, xdim)

        # Evaluate the density on the 2D grid in one vectorized call
        SI, TI = np.meshgrid(si, ti)
        # Density at 0 dbar minus 1000 kg/m^3, i.e. a density anomaly at surface pressure.
        # NOTE(review): gsw.rho is documented for Absolute Salinity and Conservative
        # Temperature - confirm that the values passed in by the caller match
        density = gsw.rho(SI, TI, 0) - 1000

        # Plot and label the isolines
        CS = plt.contour(SI, TI, density, linewidths=1, linestyles='dashed', colors='gray')
        plt.clabel(CS, fontsize=10, inline=1, fmt='%1.2f')

        # Add sigma_0 in gray in the left upper corner
        plt.text(0.02, 0.95, r"$\sigma_0$", color='gray', fontsize=20, 
                 fontweight='bold', transform=plt.gca().transAxes)

    # Create figure
    plt.figure(figsize=figsize)

    # Create a line plot of temperature vs. salinity
    if show_lines_between_dots:
        plt.plot(salinity, temperature, linestyle='-', color=linecolor, linewidth=linewidth)

    # Create a scatter plot of temperature vs. salinity
    if depth_colormap is None:
        plt.scatter(salinity, temperature, s=dot_size, color='blue')
    else:
        plt.scatter(salinity, temperature, s=dot_size, c=depth, cmap=depth_colormap)
        plt.colorbar(label='Depth [m]') # Plot legend for colormap

    # Add grid lines to the plot for better readability
    if show_grid:
        plt.grid(color='gray', linestyle='--', linewidth=0.5)

    # Set font size for ticks
    plt.tick_params(axis='both', labelsize=12)

    # Set plot labels and title
    if title:
        plt.title(title, fontsize=14, pad=30)
    plt.xlabel(label_xaxis, fontsize=12)
    plt.ylabel(label_yaxis, fontsize=12)

    # Integrate density isolines if wanted
    if show_density_isolines:
        plot_density_isolines(temperature, salinity)

    # Save the plot as file
    if output_file:
        plt.savefig(output_file)

    # Show the plot
    plt.show()

    

### Plot without depth

In [None]:
# T-S diagram of the secondary channel with uniformly coloured dots
ts_title = f"T-S Diagram for Station {station_number} ({ds.attrs['LATITUDE']:.2f}°N, {ds.attrs['LONGITUDE']:.2f}°E)"

plot_ts_diagram(
    salinity=ds['SA2'].values,
    temperature=ds['PT2'].values,
    title=ts_title,
    label_xaxis='Absolute Salinity [g/kg]',
    label_yaxis='Potential Temperature [°C]',
    dot_size=35,
    show_grid=False,
    show_density_isolines=True,
    output_file=output_path_ts,
)

### Plot with depth

In [None]:
# T-S diagram of the secondary channel with dots coloured by the DEPTH variable
ts_depth_title = f"T-S Diagram for Station {station_number} ({ds.attrs['LATITUDE']:.2f}°N, {ds.attrs['LONGITUDE']:.2f}°E)"

plot_ts_diagram(
    salinity=ds['SA2'].values,
    temperature=ds['PT2'].values,
    depth=ds['DEPTH'],
    depth_colormap='winter',
    title=ts_depth_title,
    label_xaxis='Absolute Salinity [g/kg]',
    label_yaxis='Potential Temperature [°C]',
    dot_size=75,
    linecolor='gray',
    show_grid=False,
    show_density_isolines=True,
    output_file=output_path_ts_with_depth,
)

## Buoyancy Frequency

Buoyancy frequency is a measure of how strongly a vertical profile is stratified.

### Calculation

#### Method 1: Manually

Calculate density using the Gibbs SeaWater (GSW) toolbox, see [GSW density](https://teos-10.github.io/GSW-Python/density.html).

In [None]:
# Method 1: N^2 as the vertical derivative of buoyancy,
#   N^2 = db/dz   with   b = -g * (rho - rho_0) / rho_0
g = 9.81  # gravitational acceleration, m/s^2
rho_0 = 1025  # reference density, kg/m^3
depth = ds['DEPTH']  # kept as DataArray; also used by the dynamic/steric height plot
rho = gsw.rho(ds['SA2'], ds['CT2'], ds['PRES']) # calculate in-situ density
b = -g * (rho - rho_0) / rho_0 # calculate buoyancy
# NOTE(review): rho is the in-situ density, so the derivative below also contains
# the pressure (compressibility) effect - compare against methods 2 and 3

# Differentiate on plain NumPy arrays: shifted slices are then subtracted strictly
# by position, independent of any coordinate alignment done by xarray
b_values = np.asarray(b)
z_values = np.asarray(depth)

# Central differences for interior points, one-sided differences at both ends.
# (This single pass replaces two equivalent computations; the second one
# overwrote the first and scaled rho by 1000 only to divide by 1000 again.)
N2_method1 = np.empty_like(b_values)
N2_method1[1:-1] = (b_values[2:] - b_values[:-2]) / (z_values[2:] - z_values[:-2])
N2_method1[0] = (b_values[1] - b_values[0]) / (z_values[1] - z_values[0])
N2_method1[-1] = (b_values[-1] - b_values[-2]) / (z_values[-1] - z_values[-2])

In [None]:
# Show the keys of the dataset as cell output (inspection only, nothing is stored)
list(ds.keys())

#### Method 2: GSW function

Using the GSW function for buoyancy frequency `gsw.Nsquared()` under [GSW stability](https://teos-10.github.io/GSW-Python/stability.html).

In [None]:
# Method 2: Calculate buoyancy frequency (N^2) using GSW's function, fed with the
# absolute salinity and conservative temperature of channel 2.
# p_mid is kept for reference only; the plots below use DEPTH[0:-1] as y-axis.
N2_method2, p_mid = gsw.Nsquared(ds['SA2'], ds['CT2'], ds['PRES'], lat=ds.attrs['LATITUDE']) # function returns N^2 and the midpoint pressure levels (p_mid)

#### Method 3: Adiabatic Leveling
As a more advanced method 3, the Bray and Fofonoff (1981) method - sometimes called "Fofonoff levelling" or here "adiabatic levelling" - is used. It smooths the density locally, using the function `nsq.adiabatic_leveling` from [mixsea](https://github.com/modscripps/mixsea/blob/main/mixsea/nsq.py) with a polynomial of order 1:

In [None]:
# Method 3: Applying adiabatic leveling
# Variant a: practical salinity and in-situ temperature of channel 2
N2_method3 = mx.nsq.adiabatic_leveling(ds['PRES'].values, ds['PSAL2'].values, ds['TEMP2'].values, ds.attrs['LONGITUDE'], ds.attrs['LATITUDE'])
# Variant b: same call fed with absolute salinity and conservative temperature.
# NOTE(review): the function presumably expects practical salinity and in-situ
# temperature and converts internally - confirm variant b is a meaningful input
N2_method3b = mx.nsq.adiabatic_leveling(ds['PRES'].values, ds['SA2'].values, ds['CT2'].values, ds.attrs['LONGITUDE'], ds.attrs['LATITUDE'])


### Plotting

In [None]:
def plot_buoyancy_frequency(depth: list, depth_label: str = 'Depth [m]',
                            parameters: list = None,
                            parameter_labels: list = None,
                            output_filename: str = None, title: str = None, 
                            log_scale_for_xaxis: bool = False):
    """ Plots buoyancy frequency estimates against depth, one subplot per estimate.

    Args:
        depth: Depth values shared by all subplots (y-axis).
        depth_label: Label of the y-axis; only drawn on the leftmost subplot.
        parameters: Sequence of value series, one per subplot (x-axis).
        parameter_labels: x-axis labels, one per entry of `parameters`.
        output_filename: If given, the figure is saved to this path.
        title: If given, used as title for the entire figure.
        log_scale_for_xaxis: Use a log-10 scale for the x-axes.

    Raises:
        ValueError: If `parameters` is empty or its length differs from the
            length of `parameter_labels`.
    """
    # None replaces the former mutable [] defaults
    parameters = [] if parameters is None else parameters
    parameter_labels = [] if parameter_labels is None else parameter_labels

    # Consistency checks
    if len(parameters) != len(parameter_labels):
        raise ValueError("Length of parameters list must be the same of parameter_labels list.")
    if len(parameters) == 0:
        raise ValueError("At least one parameter is required to plot the buoyancy frequency.")

    # squeeze=False always yields a 2-D array of axes, so a single
    # parameter can be indexed exactly like several
    fig, axs = plt.subplots(1, len(parameters), figsize=(10,7), sharex=True, squeeze=False)

    # One subplot per parameter
    for index, (ax, values, values_label) in enumerate(zip(axs[0], parameters, parameter_labels)):
        ax.plot(values, depth, color='blue')  # plot parameter values
        ax.set_xlabel(values_label, fontsize=12)  # label for x-axis
        if index < 1:
            ax.set_ylabel(depth_label, fontsize=12)  # y-axis label on the leftmost subplot only
        ax.invert_yaxis()  # invert y-axis since we look down in the ocean
        ax.grid(True, color='gray', linestyle='--', linewidth=0.5)  # show grid
        ax.tick_params(axis='both', labelsize=12)  # format font size of ticks
        if log_scale_for_xaxis:
            ax.set_xscale('log')  # apply log-10 scale to x-axis

    # Adjust space between subplots
    plt.subplots_adjust(wspace=0.1+((len(parameters)-1)*0.2))

    # Set title for the entire figure if given
    if title:
        fig.suptitle(title, fontsize=14)

    # Save plot if file name is given
    if output_filename:
        plt.savefig(output_filename)

    # Show plot
    plt.show()

def plot_buoyancy_frequencyB(depth: list, depth_label: str = 'Depth [m]',
                            parameters: list = None,
                            parameter_labels: list = None,
                            output_filename: str = None, title: str = None, 
                            log_scale_for_xaxis: bool = False,
                            xaxis_label: str = 'N2 [rad²/s²]'):
    """ Plots all buoyancy frequency estimates against depth in one shared axes.

    Args:
        depth: Depth values shared by all series (y-axis).
        depth_label: Label of the y-axis.
        parameters: Sequence of value series, all drawn into the same axes.
        parameter_labels: Legend entries, one per entry of `parameters`.
        output_filename: If given, the figure is saved to this path.
        title: If given, used as title for the entire figure.
        log_scale_for_xaxis: Use a log-10 scale for the x-axis.
        xaxis_label: Label of the shared x-axis.

    Raises:
        ValueError: If the lengths of `parameters` and `parameter_labels` differ.
    """
    # None replaces the former mutable [] defaults
    parameters = [] if parameters is None else parameters
    parameter_labels = [] if parameter_labels is None else parameter_labels

    # Consistency check
    if len(parameters) != len(parameter_labels):
        raise ValueError("Length of parameters list must be the same of parameter_labels list.")

    # Create figure with a single axes shared by all series
    fig, ax = plt.subplots(1, 1, figsize=(10,7))

    # Draw every series; its label becomes the legend entry
    for values, values_label in zip(parameters, parameter_labels):
        ax.plot(values, depth, label=values_label)

    # Axes-wide settings are applied exactly once. invert_yaxis() toggles the
    # direction, so calling it once per series left the axis un-inverted
    # whenever an even number of series was plotted.
    ax.invert_yaxis()  # invert y-axis since we look down in the ocean
    ax.set_xlabel(xaxis_label, fontsize=12)  # label for x-axis
    ax.set_ylabel(depth_label, fontsize=12)  # label for y-axis
    ax.grid(True, color='gray', linestyle='--', linewidth=0.5)  # show grid
    ax.tick_params(axis='both', labelsize=12)  # format font size of ticks
    if log_scale_for_xaxis:
        ax.set_xscale('log')  # apply log-10 scale to x-axis
    if len(parameters) > 0:
        ax.legend(fontsize=12)  # tell the series apart

    # Set title for the entire figure if given
    if title:
        fig.suptitle(title, fontsize=14)

    # Save plot if file name is given
    if output_filename:
        plt.savefig(output_filename)

    # Show plot
    plt.show()

#### Using linear scale for x-axis

In [None]:
# One (values, axis label) pair per subplot; methods 1 and 3 are trimmed by one
# sample and plotted, like method 2, against DEPTH without its last sample
n2_panels_linear = [
    (N2_method1[0:-1], 'N2 (method 1) [rad²/s²]'),
    (N2_method2, 'N2 (method 2) [rad²/s²]'),
    (N2_method3[0:-1], 'N2 (method 3) [rad²/s²]'),
]

plot_buoyancy_frequency(
    np.abs(ds['DEPTH'][0:-1]),
    parameters = [values for values, _ in n2_panels_linear],
    parameter_labels = [label for _, label in n2_panels_linear],
    title=f"Buoyancy Frequency for Station {station_number} ({ds.attrs['LATITUDE']:.2f}°N, {ds.attrs['LONGITUDE']:.2f}°E)",
    output_filename = output_path_buoyancy_frequency_linear_scale,
    log_scale_for_xaxis = False
)


#### Using log-10 scale for x-axis

In [None]:
# Each series sits next to its own label so the two cannot get out of step:
# the labels for methods 1 and 2 used to be attached to the wrong series and
# both method-3 variants shared the same label.
n2_series_log = [
    (N2_method2, 'N2 (method 2) [rad²/s²]'),
    (N2_method1[0:-1], 'N2 (method 1) [rad²/s²]'),
    (N2_method3[0:-1], 'N2 (method 3, from practical salinity) [rad²/s²]'),
    (N2_method3b[0:-1], 'N2 (method 3, from absolute salinity) [rad²/s²]'),
]

plot_buoyancy_frequencyB(
    np.abs(ds['DEPTH'][0:-1]),
    parameters = [values for values, _ in n2_series_log],
    parameter_labels = [label for _, label in n2_series_log],
    title=f"Buoyancy Frequency for Station {station_number} ({ds.attrs['LATITUDE']:.2f}°N, {ds.attrs['LONGITUDE']:.2f}°E)",
    output_filename = output_path_buoyancy_frequency_log_scale,
    log_scale_for_xaxis = True
)


## Dynamic and Steric Height

### Calculation

Dynamic height is related to the geostrophic current's strength and direction, providing a measure of the ocean's relative circulation patterns. The output from this function is typically in units of m² s⁻², representing the geopotential difference between two points.

In [None]:
# Dynamic height from absolute salinity, conservative temperature and pressure of channel 2.
# NOTE(review): no reference pressure is passed, so the library default applies - confirm
dynamic_height = gsw.geostrophy.geo_strf_dyn_height(ds['SA2'],ds['CT2'], ds['PRES'])

Steric height is a measure of the change in height of a water column due to changes in temperature and salinity, essentially representing the thermal expansion and contraction and the effect of salinity changes on the volume of a column of seawater. Steric height is expressed in meters (m). To convert dynamic height in m² s⁻² to steric height in meters (m), divide it by the acceleration due to gravity (g):

In [None]:
# Convert dynamic height to steric height by dividing by the gravitational acceleration
g = 9.81  # gravitational acceleration, m/s^2
steric_height = dynamic_height / g

### Plotting

In [None]:
def plot_dynamic_and_steric_height(dynamic_height, steric_height, depth, 
                depth_label='Depth [m]', output_filename=None, title=None):
    """ Plots dynamic height and steric height against depth in two side-by-side subplots.

    Args:
        dynamic_height: Dynamic height values (left subplot, x-axis).
        steric_height: Steric height values (right subplot, x-axis).
        depth: Depth values shared by both subplots (y-axis).
        depth_label: Label of the y-axis; only drawn on the left subplot.
        output_filename: If given, the figure is saved to this path.
        title: If given, used as title for the entire figure.
    """
    # Create figure for two subplots
    fig, axs = plt.subplots(1, 2, figsize=(10,7), sharex=False)

    # Steric height is dynamic height divided by g, hence metres - not m^2 s^-2
    panels = [
        (dynamic_height, 'Dynamic height [m$^2$s$^{-2}$]'),
        (steric_height, 'Steric height [m]'),
    ]

    for index, (values, values_label) in enumerate(panels):
        ax = axs[index]
        ax.plot(values, depth, color='blue')  # plot parameter values
        ax.set_xlabel(values_label, fontsize=12)  # label for x-axis
        if index < 1:
            ax.set_ylabel(depth_label, fontsize=12)  # y-axis label on the left subplot only
        ax.invert_yaxis()  # invert y-axis since we look down in the ocean
        ax.grid(True, color='gray', linestyle='--', linewidth=0.5)  # show grid
        ax.tick_params(axis='both', labelsize=12)  # format font size of ticks

    # Adjust space between subplots
    plt.subplots_adjust(wspace=0.3)

    # Set title for the entire figure if given
    if title:
        fig.suptitle(title, fontsize=14)

    # Save plot if file name is given
    if output_filename:
        plt.savefig(output_filename)

    # Show plot
    plt.show()



In [None]:
# Both heights against the magnitude of `depth` (set in the buoyancy frequency section)
height_title = f"Dynamic and Steric Height for Station {station_number} ({ds.attrs['LATITUDE']:.2f}°N, {ds.attrs['LONGITUDE']:.2f}°E)"

plot_dynamic_and_steric_height(
    dynamic_height,
    steric_height,
    np.abs(depth),
    title=height_title,
    output_filename=output_path_dynamic_steric_height,
)