In [None]:
import os
import sys
import pickle
import json
import numpy as np
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
sys.path.append(os.path.dirname(os.getcwd()))

from src.measurements import RandomMask, GridMask

# Torch device handle (CPU). NOTE(review): not referenced in the cells shown here — confirm it is still needed.
device = torch.device("cpu")

In [None]:
# Side length of the square fields; get_mask() builds its (grid_size, grid_size) centre mask from it.
# NOTE(review): must agree with the spatial resolution stored in results.npz — confirm.
grid_size = 128

def vorticity(x):
    """Scalar vorticity-like field of a 2-channel velocity field on a periodic grid.

    Channel 0 is differentiated along the last axis and channel 1 along the
    second-to-last axis. Both use second-order central differences that wrap
    around the edges, via circular padding. The result is
    ``d(ch0)/d(axis -1) - d(ch1)/d(axis -2)``.
    NOTE(review): whether this is +curl or -curl depends on the (u, v) / axis
    convention of the saved data — confirm before interpreting signs.

    Args:
        x: NumPy array or torch tensor of shape ``(*batch, 2, h, w)``.

    Returns:
        NumPy array of shape ``(*batch, h, w)``.
    """
    field = torch.from_numpy(x) if isinstance(x, np.ndarray) else x
    *lead, _, height, width = field.shape
    flat = field.reshape(-1, 2, height, width)

    # One ghost cell on every side so the central difference wraps periodically.
    padded = torch.nn.functional.pad(flat, pad=(1, 1, 1, 1), mode="circular")
    grad_u = torch.gradient(padded[:, 0], dim=-1)[0]
    grad_v = torch.gradient(padded[:, 1], dim=-2)[0]

    # Drop the ghost cells again before restoring the caller's batch shape.
    interior = (grad_u - grad_v)[:, 1:-1, 1:-1]
    return interior.reshape(*lead, height, width).numpy()

def _resolve_noise_std(noise_std):
    """Return ``noise_std``, falling back to the module-level ``noise_std`` global when it is None."""
    if noise_std is not None:
        return noise_std
    try:
        return globals()["noise_std"]
    except KeyError:
        raise NameError(
            "get_mask needs noise_std: pass it explicitly or define the global"
        ) from None


def _mask_from_measurement(measurement):
    """Turn ``measurement.mask`` into a float display mask with unobserved (0) entries set to NaN."""
    # ``.numpy()`` suggests a torch tensor. The float copy avoids mutating the
    # operator's memory and makes NaN assignable even for bool/int masks
    # (assigning inf to an int array raises OverflowError).
    mask = measurement.mask.numpy().astype(float)
    mask[mask == 0] = np.nan
    return mask


def get_mask(observation_type, sparsity=None, stride=None, noise_std=None):
    """Return a multiplicative display mask for the pattern named by ``observation_type``.

    The pattern is chosen by substring, tested in this order: ``"center"``,
    ``"random"``, ``"grid"``. Observed pixels map to 1 and hidden pixels to NaN.
    ``field * mask`` therefore keeps observed values and turns hidden ones into
    NaN, which ``imshow`` draws uniformly with the colormap's "bad" colour.
    Multiplying by ``inf`` instead gives +inf, -inf or NaN depending on the sign
    of the field, so the hidden region renders as a speckled mix of colours.

    Args:
        observation_type: label whose substring selects the pattern.
        sparsity: forwarded to ``RandomMask`` for ``"random"`` labels.
        stride: forwarded to ``GridMask`` for ``"grid"`` labels.
        noise_std: forwarded to ``RandomMask``/``GridMask``. When None, the
            notebook-level ``noise_std`` global is used (the previous behaviour).

    Returns:
        For ``"center"``, a ``(grid_size, grid_size)`` float array whose central
        half-by-half square (25% of the area) is NaN. For ``"random"``/``"grid"``,
        the operator's mask as a float array with zeros replaced by NaN. For any
        other label, the scalar ``1.`` (no masking).
        NOTE(review): ``RandomMask`` may draw a fresh random pattern on each call,
        not the one used in the experiment — confirm.
    """
    if "center" in observation_type:
        lo, hi = grid_size // 4, 3 * grid_size // 4
        mask = np.ones((grid_size, grid_size))
        mask[lo:hi, lo:hi] = np.nan  # hide the central square
        return mask
    if "random" in observation_type:
        return _mask_from_measurement(
            RandomMask(noise_std=_resolve_noise_std(noise_std), sparsity=sparsity)
        )
    if "grid" in observation_type:
        return _mask_from_measurement(
            GridMask(noise_std=_resolve_noise_std(noise_std), stride=stride)
        )
    return 1.

# Results folders of the runs shown in the figure, in figure-row order.
filtered_folders = [
    "../kolmogorov_results/kolmogorov-241102-045641",
    "../kolmogorov_results/kolmogorov-241102-045709",
    "../kolmogorov_results/kolmogorov-241108-143710",
]


# One label per folder, in the same order. get_mask() selects the mask from a
# substring of the label ("center", "random", "grid"); any other label means no mask.
# Raw strings keep the LaTeX-escaped "\%" without an invalid-escape warning.
measurement_types = [
    "8x average pooling",
    r"25\% center mask",
    r"90\% grid mask",
]

# Reduce every run to the fields plotted below and cache them to disk, so the
# plotting cell can be re-run without reloading the raw results.
assert filtered_folders and len(filtered_folders) == len(measurement_types), (
    "need exactly one measurement label per results folder"
)
df = []  # one dict per run (a list of records, not a DataFrame)
for workdir, measurement_type in zip(filtered_folders, measurement_types):
    with open(os.path.join(workdir, 'config.json'), 'r') as f:
        cfg = json.load(f)
    # Kept at module level on purpose: get_mask() reads the global ``noise_std``.
    noise_std = cfg['measurement']['noise_std']
    sparsity = cfg['measurement'].get("sparsity", None)
    stride = cfg['measurement'].get("stride", None)
    n_train = cfg['train']['n_train']
    # The context manager closes the .npz archive; indexing it loads each array eagerly.
    with np.load(os.path.join(workdir, 'results.npz')) as results:
        states = results['states']  # (steps, 2, grid_size, grid_size)
        observations = results['observations']  # (steps, 2, grid_size, grid_size)
        assimilated_states = results['assimilated_states']  # (steps, nsamples, 2, grid_size, grid_size)
    mean_estimation = np.mean(assimilated_states, axis=1)  # ensemble mean: (steps, 2, grid_size, grid_size)
    average_rmse = np.mean((states - mean_estimation)**2, axis=(1, 2, 3))**0.5  # (steps, )
    mean_vorticity = vorticity(mean_estimation)  # (steps, grid_size, grid_size)
    mask = get_mask(measurement_type, sparsity=sparsity, stride=stride)
    df.append({
        "measurement_type": measurement_type,
        "observations": observations,
        "observations_vorticity": vorticity(observations)*mask,
        "mean_estimation": mean_estimation,
        "mean_vorticity": mean_vorticity,
        "average_rmse": average_rmse,
    })

# NOTE(review): the ground-truth row uses ``states`` from the *last* run only;
# this assumes every run shares the same reference trajectory — confirm.
states_vorticity = vorticity(states)
os.makedirs("../asset", exist_ok=True)
with open("../asset/Kolmogorov_evolution.pkl", "wb") as file:
    pickle.dump((df, states_vorticity), file)

In [None]:
with open("../asset/Kolmogorov_evolution.pkl", "rb") as file:
    # pickle.load can execute arbitrary code: only load a file written by the cell above.
    df, states_vorticity = pickle.load(file)
mpl.rcdefaults()
mpl.style.use("../configs/mplrc")
# Negative wspace tightens the horizontal gap between panels.
mpl.rc("figure.subplot", wspace=-0.2, hspace=0.1)
mpl.rc("axes.spines", bottom=False, left=False)

ncols = 5  # snapshots per row
freq = 10  # column j shows time step (j + 1) * freq
nrows = 1 + 2 * len(df)  # ground-truth row, then an (observation, estimate) row pair per run
cmap = sns.cm.icefire
snapshot_steps = [(j + 1) * freq for j in range(ncols)]

fig, axes = plt.subplots(
    nrows=nrows,
    ncols=ncols + 1,  # the extra last column is switched off and hosts the colorbars
    figsize=(7, 7),
    gridspec_kw={"width_ratios": [1] * ncols + [0.9]},
)

vmin_state, vmax_state = -0.6, 0.6
state_norm = mpl.colors.Normalize(vmin=vmin_state, vmax=vmax_state)

# Per-run colour limits for the observation rows, in the same order as df.
vmin_obs = [-2.5, -0.5, -0.5]
vmax_obs = [2.5, 0.5, 0.5]
assert len(vmin_obs) >= len(df) and len(vmax_obs) >= len(df), (
    "need observation colour limits for every run"
)
obs_norms = [
    mpl.colors.Normalize(vmin=vmin_obs[i], vmax=vmax_obs[i]) for i in range(len(df))
]


def _show_field(ax, field, norm, ylabel=None):
    """Draw one snapshot on ``ax`` with the x-axis hidden and no y-ticks."""
    ax.imshow(field, cmap=cmap, norm=norm)
    ax.xaxis.set_visible(False)
    ax.yaxis.set_ticks([])
    if ylabel is not None:
        ax.set_ylabel(ylabel)


def _draw_row(row_axes, frames, norm, ylabel):
    """Fill the snapshot columns of one figure row; only the first panel gets ``ylabel``."""
    for j, (ax, step) in enumerate(zip(row_axes[:-1], snapshot_steps)):
        _show_field(ax, frames[step], norm, ylabel if j == 0 else None)


def _add_colorbar(norm, host_ax):
    """Attach a small colorbar for ``norm`` to a row's last (axis-off) column."""
    cbar = fig.colorbar(
        mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
        ax=host_ax, fraction=.8, aspect=5, shrink=.9,
    )
    cbar.ax.tick_params(labelsize=6)


_draw_row(axes[0], states_vorticity, state_norm, r"$\bf{x}$")
for data, obs_row, est_row, obs_norm in zip(df, axes[1::2], axes[2::2], obs_norms):
    _draw_row(obs_row, data["observations_vorticity"], obs_norm, r"$\bf{y}$")
    _draw_row(est_row, data["mean_vorticity"], state_norm, r"$\bf{\hat{x}}$")

for row in axes:
    row[-1].axis("off")

# State-scale colorbars on the truth and estimate rows, per-run scales on observation rows.
for row in axes[::2]:
    _add_colorbar(state_norm, row[-1])
for obs_norm, row in zip(obs_norms, axes[1::2]):
    _add_colorbar(obs_norm, row[-1])

fig.savefig('../asset/Kolmogorov_evolution.pdf', dpi=600, bbox_inches='tight', pad_inches=0.)
# plt.savefig('../asset/Kolmogorov_evolution.png', dpi=600, bbox_inches='tight', pad_inches=0.)