In [57]:
import sys
sys.path.append("..")

from tqdm import tqdm
from dapper.mods.Lorenz96 import dstep_dx, step, x0
import dapper.mods as modelling
import dapper as dpr
import dapper.da_methods as da
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from hpl.datamodule.observational_models import RandomObservationModel
from hpl.datamodule.DataLoader import L96InferenceDataset
from mdml_tools.simulators import L96Simulator
from torch.utils.data import DataLoader

In [19]:
# --- Reproducibility ---
# Single seed used for both the NumPy and the PyTorch global RNGs.
random_seed: int = 261_197

# --- Lorenz-96 model ---
grid_size: int = 36          # state dimension (also the DAPPER model size `M`)
forcing: float = 10.0        # forcing passed to L96Simulator
process_noise: float = 0.0   # `noise` entry of the DAPPER dynamical operator
time_step: float = 1e-2      # integration step; also `dt` of the DAPPER chronology

# --- Observations ---
observation_noise: float = 1.0       # passed as `additional_noise_std` to RandomObservationModel
mask_fraction: float = 0.75          # passed as `random_mask_fraction` to RandomObservationModel
steps_between_observation: int = 1   # `dko` of the DAPPER chronology
number_of_observations: int = 2_000  # `Ko` of the DAPPER chronology, and DataLoader batch size
number_of_steps_to_cut: int = 500    # leading simulation steps dropped from the ground truth

# --- Runtime / checkpoints ---
# NOTE(review): the checkpoint locations are absolute, cluster-specific paths;
# the notebook will not run elsewhere without editing them.
device = "cpu"
pretrained_network_checkpoint = "/gpfs/work/zinchenk/final_experiments/140424_imperfect_forward_operator/pretrain_data_assimilation/multirun/2024-03-14/03-19-59/0/logs/checkpoints/assimilation_network.ckpt"
trained_network_checkpoint = "/gpfs/work/zinchenk/final_experiments/300424_parametrization_learning/parametrization_learning/multirun/2024-04-30/13-17-52/0/logs/checkpoints/assimilation_network.ckpt"
# path to simulator checkpoint with coupled parametrization
simulator_checkpoint = "/gpfs/work/zinchenk/final_experiments/300424_parametrization_learning/parametrization_learning/multirun/2024-04-30/13-17-52/0/logs/checkpoints/simulator.ckpt"

In [20]:
# Seed the global PyTorch and NumPy generators with the same value so that the
# random draws made further down are repeatable. The two calls are independent
# of each other.
torch.manual_seed(random_seed)
np.random.seed(random_seed)

<torch._C.Generator at 0x7f36168ca070>

In [21]:
# The checkpoints hold whole pickled objects rather than state dicts: the
# loaded values are used directly (`.forward(...)`, `.parametrization`).
# Recent PyTorch releases default `torch.load` to `weights_only=True`, which
# refuses to unpickle such objects, so the legacy behaviour is requested
# explicitly.
# SECURITY: unpickling can execute arbitrary code -- only load checkpoints
# that come from a trusted source.
pretrained_network = torch.load(pretrained_network_checkpoint, map_location=device, weights_only=False)
trained_network = torch.load(trained_network_checkpoint, map_location=device, weights_only=False)
simulator = torch.load(simulator_checkpoint, map_location=device, weights_only=False)
parametrization = simulator.parametrization

# The networks are only used for inference in this notebook; switch them to
# evaluation mode so layers such as dropout / batch-norm (if present) behave
# deterministically.
# NOTE(review): assumes both networks are `torch.nn.Module` instances -- confirm.
pretrained_network.eval()
trained_network.eval()

In [41]:
# Random initial conditions for the two-level simulator (presumably the slow
# `X` and fast `y` variables -- confirm against L96Simulator).
# `x` is kept as an alias of `initial_state_X` for backward compatibility; the
# in-place division below therefore rescales `x` as well.
initial_state_X = x = forcing * (0.5 + torch.randn(torch.Size((1, 1, grid_size)), device="cpu") * 1.0)
initial_state_y = forcing * (0.5 + torch.randn(torch.Size((1, 1, grid_size, 10)), device="cpu") * 1.0)
# torch.max over [10, 50] evaluates to 50, so both states are divided by 50.
initial_state_X /= torch.max(torch.tensor([10, 50]))
initial_state_y /= torch.max(torch.tensor([10, 50]))

# NOTE: this rebinds `simulator`, replacing any object previously stored under
# that name.
simulator = L96Simulator(simulator_type='two_level', forcing=forcing)

# `torch.arange` with a non-integer step compares against a floating-point
# `end`, so rounding can add or drop the last element. Placing `end` half a
# step before the intended bound yields exactly `total_steps` points while
# leaving the generated values untouched.
total_steps = number_of_observations + number_of_steps_to_cut
time_to_simulate = torch.arange(
    0,
    (total_steps - 0.5) * time_step,
    time_step,
    device='cpu',
    dtype=torch.float32,
)
ground_truth, _ = simulator.integrate(time_to_simulate, (initial_state_X, initial_state_y))
# Drop singleton dimensions and discard the first `number_of_steps_to_cut` steps.
ground_truth = ground_truth.squeeze()[number_of_steps_to_cut:, :]

In [67]:
# Observation model: random masking plus additive noise, configured from the
# notebook-level settings.
observation_operator = RandomObservationModel(
    random_mask_fraction=mask_fraction,
    additional_noise_std=observation_noise,
)

# The dataset is given the ground truth with an extra leading dimension, on CPU.
ground_truth_batched = ground_truth.unsqueeze(0).to("cpu")
dataset = L96InferenceDataset(
    observation_model=observation_operator,
    ground_truth_data=ground_truth_batched,
    drop_edge_samples=True,
    input_window_extend=25,
)
# The return value is ignored, so this presumably moves the dataset in place
# -- TODO confirm against L96InferenceDataset.to.
dataset.to(device)

# Unshuffled loader; the batch size equals the configured number of observations.
loader = DataLoader(
    dataset,
    shuffle=False,
    batch_size=number_of_observations,
)

In [66]:
# Run both assimilation networks over the loader and store their analyses
# under "no_param" (pretrained network) and "param" (trained network).
#
# Outputs are collected per batch and concatenated afterwards, so nothing is
# lost if the loader ever yields more than one batch (previously every batch
# overwrote the result of the one before it).
networks = {"no_param": pretrained_network, "param": trained_network}
analyses_per_network = {label: [] for label in networks}

with torch.no_grad():
    for batch in tqdm(loader):
        for label, network in networks.items():
            analysis = network.forward(batch)
            analyses_per_network[label].append(analysis)

# NOTE(review): assumes the network outputs are tensors with the batch on
# dimension 0 -- confirm. An empty loader leaves `reconstruction` empty.
reconstruction = {
    label: torch.cat(chunks, dim=0).squeeze()
    for label, chunks in analyses_per_network.items()
    if chunks
}

100%|███████████████████████████████████████████| 1/1 [00:01<00:00,  1.08s/it]


In [68]:
Nx = grid_size
Force = forcing
# DAPPER's Lorenz-96 `step` reads the module-level `dapper.mods.Lorenz96.Force`
# (default 8); assigning the notebook variable `Force` alone does not reach it.
# Set the module attribute so the baseline model uses the configured forcing.
# It is set before calling `x0` in case `x0` depends on it.
modelling.Lorenz96.Force = forcing
x_initial = x0(grid_size)
dynamical_operator = {
    'M': Nx,
    'model': step,
    'linear': dstep_dx,
    'noise': process_noise,
}

tseq = modelling.Chronology(dt=time_step, dko=steps_between_observation, Ko=number_of_observations, Tplot=0, BurnIn=5)

# Fresh random initial conditions for a new "truth" run of the two-level
# simulator. `x` is kept as an alias of `initial_state_X`; the in-place
# division below rescales it too.
initial_state_X = x = forcing * (0.5 + torch.randn(torch.Size((1, 1, grid_size)), device="cpu") * 1.0)
initial_state_y = forcing * (0.5 + torch.randn(torch.Size((1, 1, grid_size, 10)), device="cpu") * 1.0)
# torch.max over [10, 50] evaluates to 50, so both states are divided by 50.
initial_state_X /= torch.max(torch.tensor([10, 50]))
initial_state_y /= torch.max(torch.tensor([10, 50]))

simulator = L96Simulator(simulator_type='two_level', forcing=forcing)

# Exactly `number_of_observations + 2` time points are needed: the first
# observation is dropped below, leaving `number_of_observations + 1` of them.
# `torch.arange` with a float step can gain or lose the last element through
# rounding of `end`; ending half a step early pins the length without
# changing the generated values.
n_time_points = number_of_observations + 2
time_to_simulate = torch.arange(
    0,
    (n_time_points - 0.5) * time_step,
    time_step,
    device='cpu',
    dtype=torch.float32,
)
xx, _ = simulator.integrate(time_to_simulate, (initial_state_X, initial_state_y))
xx = xx.squeeze()

# Observe the simulated trajectory, then convert everything to NumPy.
obs_mod = RandomObservationModel(additional_noise_std=observation_noise, random_mask_fraction=mask_fraction)
observations, mask_array = obs_mod.forward(xx)
xx = xx.numpy()
observations = observations.numpy()
mask_array = np.array(mask_array.numpy(), dtype=bool)
# Discard the entry belonging to the initial state.
mask_array = mask_array[1:]
observations = observations[1:]

# Keep, per time step, only the entries selected by the mask.
# NOTE(review): assumes a truthy mask value marks an observed component and
# that every step selects the same number of components (otherwise
# `np.array` receives a ragged list) -- confirm against RandomObservationModel.
yy = [observations[i, mask_array[i]] for i in range(observations.shape[0])]
yy = np.array(yy)

def observation_operator_dapper(t):
    """Build the DAPPER observation operator for observation index ``t``.

    The operator observes only the state components selected by row ``t`` of
    the module-level ``mask_array``.

    Parameters
    ----------
    t : int
        Index into ``mask_array``.

    Returns
    -------
    modelling.Operator
        Partial-identity observation operator with additive Gaussian noise.
    """
    jj = np.flatnonzero(mask_array[t])
    # DAPPER interprets a scalar `noise` as a (co)variance, whereas
    # `observation_noise` is used as a standard deviation elsewhere in this
    # notebook (`additional_noise_std`), hence the square. The two coincide
    # for a value of 1.0.
    observation_variance = observation_noise ** 2
    obs = modelling.Operator(**modelling.partial_Id_Obs(grid_size, jj), noise=observation_variance)
    return obs

# Gaussian initial-state distribution and the hidden Markov model that couples
# the dynamics, the time-dependent observation operator and the chronology.
X0 = modelling.GaussRV(C=10, mu=x_initial)
HMM = modelling.HiddenMarkovModel(
    dynamical_operator,
    {"time_dependent": observation_operator_dapper},
    tseq,
    X0,
)

# Baseline methods, built lazily and run one after another in a fixed order:
# optimal interpolation, persistence, then an ensemble Kalman smoother.
baseline_factories = (
    lambda: da.OptInterp(),
    lambda: da.Persistence(),
    lambda: da.EnKS(upd_a="PertObs", N=100, Lag=25),
)

analysis_means = []
for make_method in baseline_factories:
    xps = make_method()
    xps.assimilate(HMM, xx, yy)
    # Keep the `stats.mu.a` series recorded by the method.
    analysis_means.append(xps.stats.mu.a)

optimal_interpolation, persistence, enks_perturb_obs = analysis_means
