In [None]:
import os
import sys
from tqdm import tqdm
import numpy as np
import torch
sys.path.append(os.path.dirname(os.getcwd()))

from src.dynamics import KolmogorovFlow
from src.measurements import RandomMask
from src.utils import langevin_sampler

# Compute device for all tensors in this notebook. Prefer the first CUDA GPU,
# but fall back to the CPU so the notebook still runs (slowly) on machines
# without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
"""Parameters"""
# Settings forwarded to KolmogorovFlow(...) and RandomMask(...).
grid_size = 128
dt = 0.2
reynolds = 1_000.0
noise_std = 0.3
sparsity = 0.9
steps = 50  # the assimilation loop iterates over range(steps + 1)

# Ensemble size: number of samples drawn from dynamics.prior.
n_sample = 400

# Settings forwarded verbatim to langevin_sampler(...).
lmc_steps = 500
lmc_stepsize = 0.001
anneal_init = 0.001
anneal_decay = 0.5
anneal_steps = 1

# Reproducibility: seed NumPy's global RNG and PyTorch's default generator
# with the same value that is also passed to KolmogorovFlow(seed=...).
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# Simulator for the reference trajectory and for propagating the ensemble,
# seeded with the notebook-wide `seed`.
dynamics = KolmogorovFlow(grid_size=grid_size, reynolds=reynolds, dt=dt, seed=seed)

# Observation operator producing the sparse, noisy measurements.
measurement = RandomMask(noise_std=noise_std, sparsity=sparsity)

# Simulate one reference trajectory and drop its first `burn_in` frames
# (presumably to let the flow spin up — TODO confirm intent). The original
# notes give x0 as (1, 2, grid_size, grid_size) and the kept states as
# (steps+1, 2, grid_size, grid_size); this assumes `generate` also returns
# the initial frame.
burn_in = 50
x0: torch.Tensor = dynamics.prior(n_sample=1).to(device)
trajectory: torch.Tensor = dynamics.generate(x0=x0, steps=steps + burn_in)
states: torch.Tensor = trajectory[burn_in:, ...]

# One set of measurements per retained ground-truth state.
observations: torch.Tensor = measurement.measure(states)

# Initial prior ensemble. The original note gives its shape as
# (n_sample, 2, grid_size, grid_size).
prior: torch.Tensor = dynamics.prior(n_sample=n_sample).to(device)

_, *shape = prior.shape
# History of posterior ensembles, one per assimilation step. It is kept on the
# host rather than on `device`: with the default settings and a (2, 128, 128)
# state it is ~2.7 GB in float32, and it is only needed for the final save.
# It uses the ensemble's dtype so stored samples are not silently down-cast.
assimilated_states = torch.empty((steps + 1, n_sample, *shape), dtype=prior.dtype)

with tqdm(range(steps + 1), maxinterval=50.0, desc="state step", file=sys.stdout) as pbar:
    for i in pbar:
        # Bind this step's observation as a default argument so the closure does
        # not rely on late binding of the loop variable `i`.
        obs_i = observations[i]
        grad_potential_fn = lambda x, y=obs_i: -measurement.score_likelihood(x, y)

        # Sample the posterior for step i, starting the chain from the prior ensemble.
        posterior = langevin_sampler(
            grad_potential_fn=grad_potential_fn,
            x=prior,
            steps=lmc_steps,
            dt=lmc_stepsize,
            anneal_init=anneal_init,
            anneal_decay=anneal_decay,
            anneal_steps=anneal_steps,
        )  # (n_sample, *shape)

        # Detach before storing: if the sampler ever returns a graph-attached
        # tensor, copying it in would make the buffer require grad and break
        # the `.numpy()` call at save time.
        assimilated_states[i] = posterior.detach().cpu()

        # Propagate the posterior ensemble to become the next step's prior.
        prior = dynamics.transition(posterior)

        # Progress diagnostics: RMSE of the ensemble mean and median against
        # the ground truth at this step.
        mean_estimation = torch.mean(posterior, dim=0)  # (*shape, )
        median_estimation = torch.median(posterior, dim=0)[0]  # (*shape, )
        mean_rmse = torch.sqrt(torch.mean((mean_estimation - states[i]) ** 2))
        median_rmse = torch.sqrt(torch.mean((median_estimation - states[i]) ** 2))
        pbar.set_postfix(
            {
                "mean(RMSE)": mean_rmse.item(),
                "median(RMSE)": median_rmse.item(),
            },
            refresh=False,
        )

# Create the results directory if needed; np.savez does not create parent dirs.
out_path = f"../kolmogorov_results/mle_randmask_{sparsity:.1f}.npz"
os.makedirs(os.path.dirname(out_path), exist_ok=True)
np.savez(
    out_path,
    states=states.cpu().numpy(),  # (steps+1, 2, grid_size, grid_size)
    observations=observations.cpu().numpy(),
    assimilated_states=assimilated_states.numpy(),  # (steps+1, n_sample, 2, grid_size, grid_size)
)

