In [None]:
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cf
import numpy as np

In [None]:
# import fields and geostrophic wind

%store -r fields
[uWind,vWind,wWind,temp,geop,div,vort,geop_height] = fields

%store -r grid
[lon,lat,pressure_levels] = grid

%store -r uWindG
%store -r vWindG

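The thermal wind is approximated by the vertical shear of the geostrophic wind, i.e. the difference of the geostrophic wind between consecutive pressure levels $p_i$ and $p_{i-1}$ (consecutive entries of `pressure_levels`):
\begin{align}
\vec{v}_t(p_i) = \vec{v}_g(p_i) - \vec{v}_g(p_{i-1})
\end{align}
The first level ($i = 0$) has no preceding level, so its thermal wind is set to zero.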

In [None]:
# Thermal wind = difference of the geostrophic wind between each level and the
# previous one along the first (level) axis, computed in one vectorized step.
# The first level has no predecessor, so its thermal wind stays zero.
# np.asarray makes the subtraction purely positional even if the inputs are
# labelled arrays (no coordinate alignment between the shifted slices).
uWindT = np.zeros_like(uWindG)
vWindT = np.zeros_like(vWindG)
uWindT[1:] = np.diff(np.asarray(uWindG), axis=0)
vWindT[1:] = np.diff(np.asarray(vWindG), axis=0)

In [None]:
def therm_wind(field, N=90, S=90, W=0, E=360, pressure_level=-2, spacing=5, time_index=1):
    '''Plot `field` over the region [N, S, W, E] at `pressure_level`, with the
    thermal wind drawn as arrows and the geopotential height as contour lines (in m).

    Region boundaries:
      N    -- northern boundary in degrees north (90 = North Pole).
      S    -- southern boundary in degrees south (90 = South Pole, 0 = equator).
      W, E -- western/eastern boundaries in degrees east; W has to be smaller
              than E and no negative values are allowed.
    The degree values are used directly as array indices into `lat`/`lon`.
    NOTE(review): this assumes a regular 1-degree grid with `lat` running
    from 90 to -90 and `lon` starting at 0 -- confirm against the grid.

    field          -- array indexed as [time?, level, lat, lon]; must carry
                      `long_name` and `units` attributes (used for labels).
    pressure_level -- index into `pressure_levels` (negative counts from the end).
    spacing        -- stride between arrows in grid points (degrees on a 1-degree grid).
    time_index     -- index on the first axis of `field` and `geop_height`
                      (default 1, the previously hard-coded value).
                      NOTE(review): uWindT/vWindT have no such axis -- confirm
                      they belong to the same time step.

    Reads the notebook globals lon, lat, pressure_levels, geop_height,
    uWindT and vWindT. Returns nothing; the figure is drawn as a side effect.
    '''
    # Degrees -> row/column indices (see the grid assumption above).
    lat_sl = slice(90 - N, 90 + S)
    lon_sl = slice(W, E)
    lon_sub = lon[lon_sl]
    lat_sub = lat[lat_sl]

    fig, ax = plt.subplots(figsize=(15, 8), subplot_kw={'projection': ccrs.PlateCarree()})
    im = ax.contourf(lon_sub, lat_sub, field[time_index, pressure_level, lat_sl, lon_sl],
                     cmap='viridis', levels=20)

    contours = ax.contour(lon_sub, lat_sub, geop_height[time_index, pressure_level, lat_sl, lon_sl])
    ax.clabel(contours, contours.levels, inline=True, colors='k')

    # Thin the arrows: keep every `spacing`-th grid point in both directions.
    u_sub = np.asarray(uWindT[pressure_level, lat_sl, lon_sl])[::spacing, ::spacing]
    v_sub = np.asarray(vWindT[pressure_level, lat_sl, lon_sl])[::spacing, ::spacing]
    quiv = ax.quiver(lon_sub[::spacing], lat_sub[::spacing], u_sub, v_sub)

    # Reference arrow = largest plotted wind *speed*. Using the max of the
    # u-component alone can be zero or negative (e.g. easterly flow), which
    # yields a meaningless or reversed key arrow.
    key_speed = float(np.nanmax(np.hypot(u_sub, v_sub)))
    ax.quiverkey(quiv, 0.5, -0.15, key_speed,
                 label=f"{key_speed:.0f} m/s thermal wind velocity",
                 labelpos='E')

    ax.add_feature(cf.COASTLINE)
    ax.add_feature(cf.BORDERS)
    # NOTE(review): intent of these two calls is unclear; GeoAxes and plain
    # matplotlib Axes may interpret the second positional argument differently
    # (tick labels vs. `minor`). Kept unchanged to preserve current rendering.
    ax.set_xticks([0], [0])
    ax.set_yticks([0], [0])

    # Size the colorbar from the aspect ratio of the *plotted* region so it
    # matches the map width for regional plots too (not just the full grid).
    fig.colorbar(im, orientation='horizontal', fraction=0.039 * len(lon_sub) / len(lat_sub),
                 label=f"{field.long_name} [{field.units}]")
    # np.asarray yields a plain scalar even if pressure_levels is a labelled
    # array, so the title shows the number rather than an object repr.
    level_value = np.asarray(pressure_levels)[pressure_level]
    ax.set_title(f"{field.long_name} at p = {level_value} hPa", fontsize=20)
    fig.tight_layout()

### Use the function to display the thermal wind in different regions and at different altitudes

In [None]:
# Available pressure levels in hPa; the `pressure_level` argument of
# therm_wind indexes into this sequence (negative values count from the end).
pressure_levels

In [None]:
# Whole globe at the third-from-last entry of `pressure_levels`, arrows every 7 grid points.
therm_wind(temp, pressure_level=-3, spacing=7);

In [None]:
# Regional view: 80°N down to the equator, 270°E to 360°E (i.e. 90°W to the
# prime meridian), third-from-last pressure level, arrows every 2 grid points.
therm_wind(
    temp,
    N=80, S=0,
    W=270, E=360,
    pressure_level=-3,
    spacing=2,
);