In [None]:
import matplotlib.pyplot as plt
import xarray as xr
import numpy as np

In [None]:
# Settings: placement of the sub-domain (dom1) inside the outer domain (dom0),
# expressed in the same coordinate units as the cross-section files.
xstart_sub, ystart_sub = 16_000, 16_000   # lower-left corner of dom1
xsize_sub, ysize_sub = 32_000, 32_000     # extent of dom1 in x and y

# Upper-right corner of dom1 (used to draw the sub-domain outline).
xend_sub, yend_sub = xstart_sub + xsize_sub, ystart_sub + ysize_sub

In [None]:
def read_crosses(var, mode, base_dir='test'):
    """Open the cross-section files of `var` for the outer and the sub-domain.

    Parameters
    ----------
    var : str
        Variable name, used as the file-name prefix.
    mode : str
        Cross-section type, used as the file-name infix (e.g. 'xy', 'xz').
    base_dir : str, optional
        Experiment directory containing the `dom0/` and `dom1/` sub-directories.
        Defaults to 'test', the previously hard-coded location.

    Returns
    -------
    tuple of xarray.Dataset
        (dom0 dataset, dom1 dataset), opened with `decode_times=False`.
        The caller is responsible for closing both datasets.
    """
    ds1 = xr.open_dataset(f'{base_dir}/dom0/{var}.{mode}.nc', decode_times=False)
    ds2 = xr.open_dataset(f'{base_dir}/dom1/{var}.{mode}.nc', decode_times=False)
    return ds1, ds2


def plot_xy(var, t, k):
    """Plot a horizontal (xy) cross-section of `var` from both domains in one figure.

    The dom1 field is shifted by (xstart_sub, ystart_sub) and drawn on top of
    the dom0 field. Both meshes share one colour scale, and the sub-domain
    outline is drawn as a dotted red rectangle.

    Parameters
    ----------
    var : str
        Variable name; selects `<var>.xy.nc` in each domain directory.
    t : int
        Index into the first dimension of the variable (e.g. -1 for the last time).
    k : int
        Index into the second dimension of the variable.
        NOTE(review): the same index is used for both domains, which assumes
        both files store the same cross-section levels. TODO confirm.
    """
    ds1, ds2 = read_crosses(var, 'xy')

    # Close both files once plotting is done; they were previously left open
    # for every figure drawn.
    with ds1, ds2:
        # u uses the 'xh' coordinate and v the 'yh' coordinate
        # (presumably staggered half-grid positions).
        xdim = 'xh' if var == 'u' else 'x'
        ydim = 'yh' if var == 'v' else 'y'

        v1 = ds1[var][t, k, :, :]
        v2 = ds2[var][t, k, :, :]

        x1, y1 = ds1[xdim], ds1[ydim]
        # dom1 coordinates are local; shift them into the dom0 frame.
        x2 = ds2[xdim] + xstart_sub
        y2 = ds2[ydim] + ystart_sub

        # Common colour limits as plain floats so both meshes share one scale.
        vmin = min(float(v1.min()), float(v2.min()))
        vmax = max(float(v1.max()), float(v2.max()))

        fig, ax = plt.subplots(layout='constrained')
        ax.set_title(f'{var}: k={k}, t={t}')
        ax.pcolormesh(x1, y1, v1, vmin=vmin, vmax=vmax)
        mesh = ax.pcolormesh(x2, y2, v2, vmin=vmin, vmax=vmax)
        fig.colorbar(mesh, ax=ax)

    # Outline of the sub-domain (closed rectangle).
    ax.plot([xstart_sub, xend_sub, xend_sub, xstart_sub, xstart_sub],
            [ystart_sub, ystart_sub, yend_sub, yend_sub, ystart_sub], 'r:')
    ax.set_xlabel('x')
    ax.set_ylabel('y')


def plot_xz(var, t, j, zmax=None):
    """Plot a vertical (xz) cross-section of `var` from both domains in one figure.

    The dom1 field is shifted by xstart_sub in x and drawn on top of the dom0
    field. Both meshes share one colour scale, and the x-extent of the
    sub-domain is marked with dotted red vertical lines.

    Parameters
    ----------
    var : str
        Variable name; selects `<var>.xz.nc` in each domain directory.
    t : int
        Index into the first dimension of the variable (e.g. -1 for the last time).
    j : int
        Index into the third dimension of the variable.
        NOTE(review): the same index is used for both domains, which assumes
        both files store cross-sections at matching y positions. TODO confirm.
    zmax : float or None, optional
        Upper limit of the vertical axis (lower limit 0). If None, the vertical
        range is left to matplotlib's autoscaling.
    """
    ds1, ds2 = read_crosses(var, 'xz')

    # Close both files once plotting is done; they were previously left open.
    with ds1, ds2:
        # u uses the 'xh' coordinate and w the 'zh' coordinate
        # (presumably staggered half-grid positions).
        xdim = 'xh' if var == 'u' else 'x'
        zdim = 'zh' if var == 'w' else 'z'

        v1 = ds1[var][t, :, j, :]
        v2 = ds2[var][t, :, j, :]

        x1, z1 = ds1[xdim], ds1[zdim]
        # dom1 x-coordinates are local; shift them into the dom0 frame.
        x2 = ds2[xdim] + xstart_sub
        z2 = ds2[zdim]

        # Common colour limits as plain floats so both meshes share one scale.
        vmin = min(float(v1.min()), float(v2.min()))
        vmax = max(float(v1.max()), float(v2.max()))

        fig, ax = plt.subplots(layout='constrained')
        ax.set_title(f'{var}: j={j}, t={t}')
        ax.pcolormesh(x1, z1, v1, vmin=vmin, vmax=vmax)
        mesh = ax.pcolormesh(x2, z2, v2, vmin=vmin, vmax=vmax)
        fig.colorbar(mesh, ax=ax)

        # Mark the sub-domain's x-extent up to the top of the dom0 column.
        ztop = float(z1.max())
        for x_edge in (xstart_sub, xend_sub):
            ax.plot([x_edge, x_edge], [0, ztop], 'r:')

    ax.set_xlabel('x')
    ax.set_ylabel('z')
    if zmax is not None:
        ax.set_ylim(0, zmax)


# Horizontal cross-sections (index k=1) at the last time index.
for var_name in ('thl', 'qt', 'u', 'v', 'w'):
    plot_xy(var_name, t=-1, k=1)

# Vertical cross-sections (index j=0) at the last time index, shown up to z=1250.
for var_name in ('thl', 'qt', 'u', 'v', 'w'):
    plot_xz(var_name, t=-1, j=0, zmax=1250)
