In [None]:
import os
import sys
import pickle
import json
import numpy as np
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
sys.path.append(os.path.dirname(os.getcwd()))

from src.measurements import RandomMask

device = torch.device("cpu")

In [None]:
def vorticity(x):
    """Compute a scalar vorticity-like field from a two-channel velocity field.

    Args:
        x: ``np.ndarray`` or ``torch.Tensor`` of shape ``(*batch, 2, h, w)``.
            Channel 0 and channel 1 are the two velocity components
            (presumably ``u`` and ``v`` -- confirm against the dataset).
            Torch tensors may require grad or live on any device.

    Returns:
        ``np.ndarray`` of shape ``(*batch, h, w)`` equal to
        ``d(channel 0)/d(last axis) - d(channel 1)/d(second-to-last axis)``,
        using periodic (wrap-around) central differences with unit spacing,
        i.e. ``(f[i+1] - f[i-1]) / 2`` with indices taken modulo the size.

    Raises:
        ValueError: if the channel axis (third from last) is not of size 2.

    NOTE(review): whether this equals +omega or -omega depends on which array
    axis is the physical x direction; confirm the sign convention.
    """
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x)
    # Detach so tensors that require grad can be converted with .numpy() below.
    x = x.detach()
    *batch, channels, h, w = x.shape
    if channels != 2:
        # Without this check the reshape below would silently interleave channels.
        raise ValueError(f"expected 2 velocity channels, got shape {tuple(x.shape)}")
    y = x.reshape(-1, 2, h, w)
    # Circular padding turns the interior central differences into periodic ones.
    y = torch.nn.functional.pad(y, pad=(1, 1, 1, 1), mode="circular")
    (du,) = torch.gradient(y[:, 0], dim=-1)
    (dv,) = torch.gradient(y[:, 1], dim=-2)
    y = du - dv
    # Drop the padded ring (its one-sided edge differences are not periodic).
    y = y[:, 1:-1, 1:-1]
    y = y.reshape(*batch, h, w)
    # .cpu() makes CUDA inputs convertible; it is a no-op for CPU tensors.
    return y.cpu().numpy()

workdir = "../kolmogorov_results/kolmogorov-241102-100415"
with open(os.path.join(workdir, 'config.json'), 'r') as f:
    cfg = json.load(f)
noise_std = cfg['measurement']['noise_std']
steps = cfg['dynamics']['steps']
grid_size = cfg['dynamics']['grid_size']
kernel_size = cfg['measurement']['kernel_size']
sparsity = cfg['measurement']['sparsity']
n_train = cfg['train']['n_train']

# Load every array eagerly and close the .npz handle; np.load otherwise keeps
# the underlying file open for lazy access.
with np.load(os.path.join(workdir, 'results.npz')) as npz:
    results = {key: npz[key] for key in npz.files}

states = results['states'] # (steps, 2, grid_size, grid_size)
observations = results['observations'] # (steps, 2, grid_size, grid_size)

# NOTE(review): this re-instantiates the mask instead of loading the one used
# to generate `observations`; confirm RandomMask is deterministic (e.g. seeded).
measurement = RandomMask(noise_std=noise_std, sparsity=sparsity)
mask = measurement.mask.numpy()
# Hide unobserved pixels as NaN (imshow renders non-finite values as "bad").
# np.where avoids mutating `mask`, which shares storage with measurement.mask,
# and works whether the mask is boolean or 0/1 floats.
observations_vorticity = np.where(mask != 0, vorticity(observations), np.nan) # (steps, grid_size, grid_size), assuming mask is (grid_size, grid_size)
states_vorticity = vorticity(states) # (steps, grid_size, grid_size)

ssls_states = results['assimilated_states'] # (steps, nsamples, 2, grid_size, grid_size)
ssls_mean = np.mean(ssls_states, axis=1) # (steps, 2, grid_size, grid_size)
ssls_vorticity_all = vorticity(ssls_states) # (steps, nsamples, grid_size, grid_size)
ssls_vorticity = np.mean(ssls_vorticity_all, axis=1) # (steps, grid_size, grid_size)
ssls_vorticity_bias = np.abs(ssls_vorticity - states_vorticity) # (steps, grid_size, grid_size)
ssls_vorticity_std = np.std(ssls_vorticity_all, axis=1) # (steps, grid_size, grid_size)

# Cache the plotted fields so the figure cell can be re-run without recomputing.
os.makedirs("../asset", exist_ok=True)
with open("../asset/Kolmogorov_uq.pkl", "wb") as file:
    pickle.dump((
        states_vorticity,
        observations_vorticity,
        ssls_vorticity,
        ssls_vorticity_bias,
        ssls_vorticity_std,
    ), file)

In [None]:
with open("../asset/Kolmogorov_uq.pkl", "rb") as file:
    states_vorticity, observations_vorticity, \
    ssls_vorticity, ssls_vorticity_bias, \
    ssls_vorticity_std = pickle.load(file)

mpl.rcdefaults()
mpl.style.use("../configs/mplrc")
mpl.rc("figure.subplot", wspace=-0.4, hspace=0.2)
mpl.rc("axes.spines", bottom=False, left=False)

ncols = 5  # snapshots per row
freq = 10  # time steps between snapshots: shows steps freq, 2*freq, ..., ncols*freq
cmap = sns.cm.icefire
assert len(states_vorticity) > ncols * freq, "not enough time steps for the requested snapshots"

# Colour scales: [0] shared by x, y and x_hat; [1] absolute error; [2] std-dev.
vmins = [-0.6, ssls_vorticity_bias.min(), ssls_vorticity_std.min()]
vmaxs = [0.6, 0.6, 0.4]
norms = [mpl.colors.Normalize(vmin=vmin, vmax=vmax) for vmin, vmax in zip(vmins, vmaxs)]

# One entry per figure row: (fields over time, index into `norms`, y-axis label).
rows = [
    (states_vorticity, 0, r"$\bf{x}$"),
    (observations_vorticity, 0, r"$\bf{y}$"),
    (ssls_vorticity, 0, r"$\bf{\hat{x}}$"),
    (ssls_vorticity_bias, 1, r"$\bf{|\hat{x} - x|}$"),
    (ssls_vorticity_std, 2, "std-dev"),
]

fig, axes = plt.subplots(
    nrows=len(rows),
    ncols=ncols + 1,
    figsize=(7, 6.5),
    # The narrow last column only hosts the colour bars.
    gridspec_kw={"width_ratios": [1] * ncols + [0.4]},
)

for row_axes, (fields, norm_idx, label) in zip(axes, rows):
    for j, axj in enumerate(row_axes[:-1]):
        axj.imshow(fields[(j + 1) * freq], cmap=cmap, norm=norms[norm_idx])
        axj.xaxis.set_visible(False)
        axj.yaxis.set_ticks([])
    row_axes[0].set_ylabel(label)

# (norm index, host axes in the last column, colour-bar aspect); the first bar spans rows 0-2.
colorbar_specs = [
    (0, axes[:3, -1], 20),
    (1, axes[3, -1], 6.66),
    (2, axes[4, -1], 6.66),
]
for norm_idx, host_axes, aspect in colorbar_specs:
    cbar = fig.colorbar(
        mpl.cm.ScalarMappable(norm=norms[norm_idx], cmap=cmap),
        ax=host_axes, fraction=.5, aspect=aspect, shrink=1.,
    )
    cbar.ax.tick_params(labelsize=6)

# The last-column axes only donate space to the colour bars; hide them.
for ax in axes[:, -1]:
    ax.axis("off")

fig.savefig('../asset/Kolmogorov_uq.pdf', dpi=600, bbox_inches='tight', pad_inches=0.1)