In [None]:
import xarray as xr
import numpy as np
import pandas as pd

from scipy.stats import moment


from diffusion_nextsim.deformation import estimate_deform

import matplotlib.pyplot as plt
import matplotlib.colors as mpl_c
import matplotlib.gridspec as mpl_gs
import matplotlib.dates as mpl_dates
import cartopy.crs as ccrs
import cartopy
import cmocean

In [None]:
# Apply the project style sheets in order; settings from "wiley" override
# those from "paper" where both define the same rcParam.
plt.style.use(["paper", "wiley"])

# Load data

In [None]:
# One NetCDF file per experiment (neXtSIM, deterministic, diffusion).
ds_nextsim, ds_deterministic, ds_diffusion = map(
    xr.open_dataset,
    (
        "data/consistency_nextsim.nc",
        "data/consistency_deterministic.nc",
        "data/consistency_diffusion.nc",
    ),
)

## Estimate deformation

In [None]:
# Deformation for every experiment, multiplied by 86400 (presumably s^-1 ->
# day^-1, matching the day^-q units used in the multiscaling figure) and
# evaluated eagerly with .compute().
deform_nextsim, deform_deterministic, deform_diffusion = (
    (estimate_deform(dataset) * 86400).compute()
    for dataset in (ds_nextsim, ds_deterministic, ds_diffusion)
)

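# Scaling analysis

The total deformation rate is block-averaged over coarsening factors $1, 2, 4, \dots, 64$. Its raw moments $\langle \dot{\epsilon}^{q}_{\mathsf{tot}} \rangle$ for $q = 1, 2, 3$ are then fitted, in log–log space, to a power law in the spatial scale $L$: $\langle \dot{\epsilon}^{q}_{\mathsf{tot}} \rangle \propto L^{-\beta(q)}$.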

In [None]:
# Spatial scale (km) of the un-coarsened field and the coarse-graining
# factors 1, 2, 4, ..., 64.
base_res = 12
coarsening = np.power(2, np.arange(7))
# Spatial scale (km) represented by each coarsening factor.
resolution = base_res * coarsening
# Scatter marker for moment orders q = 1, 2, 3, respectively.
markers = [".", "+", "*"]

In [None]:
def compute_scaling_moments(deform, coarsening, orders=(1, 2, 3)):
    """Raw moments of total deformation after spatial coarse-graining.

    For each factor ``c`` the ``deform_tot`` field is block-averaged over
    ``c x c`` cells in ``x``/``y``. Raw (about-zero, ``center=0``) moments of
    the flattened result are then taken, ignoring NaNs.

    Parameters
    ----------
    deform : xarray object
        Must provide a ``deform_tot`` variable with ``x`` and ``y`` dimensions
        (presumably the Dataset returned by ``estimate_deform`` -- confirm).
    coarsening : iterable of int
        Coarse-graining factors applied to both ``x`` and ``y``.
    orders : array_like of int
        Moment orders ``q``; defaults to 1, 2, 3.

    Returns
    -------
    numpy.ndarray
        Shape ``(len(coarsening), len(orders))``; row ``k`` holds the moments
        for ``coarsening[k]``.
    """
    orders = np.asarray(orders)
    return np.stack([
        moment(
            deform["deform_tot"].coarsen(x=c, y=c).mean().values.flatten(),
            moment=orders,
            nan_policy="omit",
            center=0,
        )
        for c in coarsening
    ])


ds_scaling = compute_scaling_moments(deform_nextsim, coarsening)
det_scaling = compute_scaling_moments(deform_deterministic, coarsening)
diff_scaling = compute_scaling_moments(deform_diffusion, coarsening)

In [None]:
def fit_scaling_exponents(scaling, log_res_perts, log_res_diff):
    """Fit power-law exponents of moments against spatial scale in log10 space.

    Parameters
    ----------
    scaling : numpy.ndarray
        Moments, shape ``(n_scales, n_orders)``.
    log_res_perts : numpy.ndarray
        ``log10(scale)`` minus its mean, shape ``(n_scales,)``.
    log_res_diff : numpy.ndarray
        Consecutive differences of ``log10(scale)``, shape ``(n_scales - 1,)``.

    Returns
    -------
    tuple
        ``(log_scaling, log_scaling_mean, log_scaling_perts, coeff_mean, coeff)``:
        the log10 moments, their mean over scales, their anomalies, the negated
        least-squares slope per order (shape ``(n_orders,)``), and the negated
        local slopes between consecutive scales (shape
        ``(n_scales - 1, n_orders)``).
    """
    log_scaling = np.log10(scaling)
    log_scaling_mean = log_scaling.mean(axis=0)
    log_scaling_perts = log_scaling - log_scaling_mean
    # Both sides are centred, so a zero-intercept least-squares fit gives the
    # ordinary regression slope. rcond=None selects NumPy's current default
    # cut-off explicitly (older NumPy warns when rcond is omitted).
    slope = np.linalg.lstsq(log_res_perts[:, None], log_scaling_perts, rcond=None)[0]
    coeff_mean = -slope.squeeze()
    # Finite-difference exponents between each pair of neighbouring scales.
    coeff = -np.diff(log_scaling, axis=0) / log_res_diff[:, None]
    return log_scaling, log_scaling_mean, log_scaling_perts, coeff_mean, coeff


log_res = np.log10(resolution)
log_res_mean = log_res.mean()
log_res_perts = log_res - log_res_mean
log_res_diff = np.diff(log_res)

(log_ds_scaling, log_ds_scaling_mean, log_ds_scaling_perts,
 log_ds_coeff_mean, log_ds_coeff) = fit_scaling_exponents(ds_scaling, log_res_perts, log_res_diff)

(log_det_scaling, log_det_scaling_mean, log_det_scaling_perts,
 log_det_coeff_mean, log_det_coeff) = fit_scaling_exponents(det_scaling, log_res_perts, log_res_diff)

(log_diff_scaling, log_diff_scaling_mean, log_diff_scaling_perts,
 log_diff_coeff_mean, log_diff_coeff) = fit_scaling_exponents(diff_scaling, log_res_perts, log_res_diff)

In [None]:
fig, ax = plt.subplots(nrows=2, figsize=(3, 4), dpi=150)

# Panel (a): moments against spatial scale (markers, one marker style per
# moment order q) together with their fitted power laws (lines).
ax[0].grid()
for i in range(3):
    # Fitted power law for order q = i + 1 only. Passing the full
    # (n_scales, 3) array here would draw all three fits on every iteration,
    # i.e. each fit line three times over.
    ax[0].plot(
        resolution,
        np.power(10, log_ds_scaling_mean[i] - log_res_perts * log_ds_coeff_mean[i]),
        c="black", ls="-", zorder=2, lw=0.75
    )
    ax[0].plot(
        resolution,
        np.power(10, log_det_scaling_mean[i] - log_res_perts * log_det_coeff_mean[i]),
        c="#81B3D5", ls="--", zorder=2, lw=0.75
    )
    ax[0].plot(
        resolution,
        np.power(10, log_diff_scaling_mean[i] - log_res_perts * log_diff_coeff_mean[i]),
        c="#9E62A6", ls="-", zorder=2, lw=0.75
    )
    ax[0].scatter(resolution, ds_scaling[:, i], c="black", marker=markers[i], zorder=3)
    ax[0].scatter(resolution, det_scaling[:, i], c="#81B3D5", marker=markers[i], zorder=3)
    ax[0].scatter(resolution, diff_scaling[:, i], c="#9E62A6", marker=markers[i], zorder=3)

ax[0].set_ylabel(r"$\langle \dot{\epsilon}^{q}_{\mathsf{tot}} \rangle$ $(\mathsf{day}^{-q})$")
ax[0].set_yscale("log")
ax[0].set_ylim(5E-5, 2E-1)
ax[0].set_xlabel("Spatial scale (km)")
ax[0].set_xlim(10, 1000)
ax[0].set_xscale("log")

# Hand-placed labels for the three moment orders near the right edge of panel (a).
ax[0].text(870, 0.045, "q=1")
ax[0].text(870, 0.0025, "q=2")
ax[0].text(870, 0.0002, "q=3")
ax[0].text(0.02, 0.98, "(a)", ha="left", va="top", transform=ax[0].transAxes)

# Panel (b): fitted exponents beta(q). Error bars span from the minimum to the
# maximum (quantiles 0 and 1) of the local slopes between consecutive scales.
ax[1].grid(ls="dotted", lw=0.5)
ax[1].errorbar(
    np.arange(1, 4),
    log_ds_coeff_mean,
    yerr=np.abs(np.quantile(log_ds_coeff, q=np.array([0, 1]), axis=0)-log_ds_coeff_mean),
    c="black", capsize=4, label="neXtSIM", zorder=97
)
ax[1].errorbar(
    np.arange(1, 4),
    log_det_coeff_mean,
    yerr=np.abs(np.quantile(log_det_coeff, q=np.array([0, 1]), axis=0)-log_det_coeff_mean),
    c="#81B3D5", capsize=4, ls="--", label="Deterministic", zorder=98
)
ax[1].errorbar(
    np.arange(1, 4),
    log_diff_coeff_mean,
    yerr=np.abs(np.quantile(log_diff_coeff, q=np.array([0, 1]), axis=0)-log_diff_coeff_mean),
    c="#9E62A6", capsize=4, label="ResDiffusion", zorder=99
)
ax[1].text(0.02, 0.98, "(b)", ha="left", va="top", transform=ax[1].transAxes)
ax[1].set_xlabel(r"Moment q")
ax[1].set_ylabel(r"Structure function $\beta(q)$")
ax[1].set_xticks([1, 2, 3])
# framealpha expects a float in [0, 1]; 1.0 gives an opaque legend frame.
ax[1].legend(framealpha=1.0, loc=2, bbox_to_anchor=(0.1, 1.05))
fig.subplots_adjust(hspace=0.4)
fig.align_ylabels(ax)

fig.savefig("figures/fig_08_multiscaling.png", dpi=300)