# Inference in a PDE model



Look, we can do approximate probabilistic inference with a complicated neural network!

## Set up

In [None]:
import numpy as np
from math import sqrt, ceil
from typing import Any, Dict, Tuple, Optional

import matplotlib.pyplot as plt
from matplotlib import rc
from matplotlib_inline.backend_inline import set_matplotlib_formats
set_matplotlib_formats('png')
# plt.rcParams.update({figure.figsize'=[12, 12]})
# plt.rcParams.update({'figure.dpi': 200})
plt.rcParams.update({'font.size': 20})
%matplotlib inline

from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv());

import h5py

import torch
import torch.fft as fft
import torch.nn as nn
from torch.optim import AdamW

from src.nn_modules.fourier_2d_generic import SimpleBlock2dGeneric
from src.heatmap import multi_heatmap
from src.utils import resolve_path

# Device used as `map_location` when loading the checkpoint and as the target
# of `.to(device)` for input tensors in later cells.
# Swap the two assignments below to run on a GPU instead of the CPU.
# device = torch.device('cuda')
device = torch.device('cpu')


This dataset is packed as $(b, x, y, t)$, i.e. batch first, spatial coordinates in the middle, time last.


In [None]:
# Open the HDF5 dataset; the relative path is passed through the project
# helper `resolve_path` first.
# NOTE(review): no mode is passed to h5py.File and the handle is not closed
# in the visible cells — confirm that is intended.
data_file = h5py.File(resolve_path('./data/grf_forcing_pico.h5'))
# Everything below reads from the entry stored under the 'valid' key.
data = data_file['valid']

def get_obs(data, t=0, n_steps=2):
    """
    Cut an observation window out of a dataset.

    `data` is indexed by key: `data['u']` is sliced along its last axis to
    the `n_steps` entries starting at index `t`, and `data['f']` is read in
    full.

    Returns a dict with the window under 'x' and the full 'f' entry under
    'latent'.
    """
    window = slice(t, t + n_steps)
    observed = data['u'][..., window]
    full_latent = data['f'][...]
    return dict(x=observed, latent=full_latent)


In [None]:

def multi_img_time(x, batch=0, interval=1, n_cols=None, fsize=6):
    """
    Plot multiple timesteps of an array (time last) as a grid of images.

    Parameters
    ----------
    x : array indexed as ``x[batch, ..., i]``; the last axis is time.
    batch : index along the first axis selecting which sample to draw.
    interval : stride between the plotted time indices.
    n_cols : number of grid columns; defaults to one column per plotted step.
    fsize : figure height; the width is ``fsize * n_cols / n_rows``.

    Returns
    -------
    The matplotlib figure holding the grid.
    """
    steps = range(0, x.shape[-1], interval)
    if n_cols is None:
        n_cols = len(steps)
    n_rows = ceil(len(steps) / n_cols)
    # squeeze=False always yields a 2-D array of axes, so a 1x1 grid (a single
    # plotted step) can be flattened like any other layout.
    fig, axes = plt.subplots(
        n_rows, n_cols,
        figsize=(fsize * n_cols / n_rows, fsize),
        squeeze=False)
    axes = axes.flatten()
    # Hide every axis up front so unused trailing cells stay blank.
    for ax in axes:
        ax.set_axis_off()
    for i, ax in zip(steps, axes):
        ax.imshow(x[batch, ..., i])
    plt.tight_layout()
    return fig


In [None]:
# Take the 10-step window starting at t=0 and keep only its 'x' entry.
arr = get_obs(data, 0, 10)['x']

# Plot every step of batch element 0 in a 2-column grid. The trailing `;`
# keeps the notebook from echoing the returned figure object.
multi_img_time(arr, n_cols=2);


## We have a process predictor model

In [None]:
# Load the saved weights of the forward model, mapped onto `device`.
# NOTE(review): unlike the dataset path, this path is not passed through
# `resolve_path` — confirm the notebook's working directory makes that safe.
pp_state_dict = torch.load(
    './models/fno_forward.ckpt',
    map_location=device
)
# Build the network and place it on `device`, the same device the inputs are
# moved to in `predict_forward`; without this the module would stay on the CPU
# when `device` is switched to CUDA.
# NOTE(review): the hyper-parameters must match those the checkpoint was
# trained with, otherwise `load_state_dict` will reject the state dict.
process_predictor = SimpleBlock2dGeneric(
    modes1=16,
    width=24,
    n_layers=4,
    n_history=2,
    param=False,
    forcing=False,
    latent=True,
).to(device)
process_predictor.load_state_dict(
    pp_state_dict
)

In [None]:
def predict_forward(x, latent, n_horizon=1, n_steps=2):
    """
    Roll the module-level `process_predictor` forward autoregressively.

    Parameters
    ----------
    x : observed history with time on the last axis; converted to a tensor
        and moved to the module-level `device`.
    latent : latent field handed unchanged to every predictor call.
    n_horizon : number of predictor calls to chain.
    n_steps : how many trailing time slices of the running sequence are fed
        to the predictor as its 'x' input on each call.

    Returns
    -------
    The trailing `n_horizon` time slices of the extended sequence; an empty
    slice along time when `n_horizon` is 0.
    """
    x = torch.as_tensor(x).to(device)
    latent = torch.as_tensor(latent).to(device)
    for _ in range(n_horizon):
        pred = process_predictor({'x': x[..., -n_steps:], 'latent': latent})['forecast']
        # Append the forecast so the next call sees it as part of the history.
        x = torch.cat((x, pred), dim=-1)
    # `x[..., -0:]` would return the whole sequence, so compute the start
    # index explicitly (clamped so an over-long horizon still returns all).
    start = max(x.shape[-1] - n_horizon, 0)
    return x[..., start:]

# Observation dict for the 10-step window starting at t=0
# ('x' is the window, 'latent' is the full 'f' entry).
obs = get_obs(data, 0, 10)

# Chain 100 predictor calls without tracking gradients. Only the last
# `n_steps=2` slices of the running sequence are fed to each call.
with torch.no_grad():
    pred = predict_forward(obs['x'], obs['latent'], n_horizon=100, n_steps=2)

In [None]:
# Copy the forecast to host memory as a NumPy array and plot every 5th time
# slice of batch element 0 in a 5-column grid.
multi_img_time(pred.cpu().numpy(), n_cols=5, interval=5);


## Inversion by gradient descent (GD)

In [None]:

def plot_heatmap(model, i, loss, error, loss_fn, batch, pred, *args, **kwargs):
    """
    Show target, estimate and error heatmaps for the first latent in a batch.

    The positional signature matches the `callback` invoked by `fit`
    (`model, i, loss, error, loss_fn, batch, pred`); only `model` and `batch`
    are read here. Extra positional and keyword arguments are forwarded to
    `multi_heatmap`.

    Parameters
    ----------
    model : object exposing the current estimate as `model.latent`.
    batch : mapping whose 'latent' entry holds the ground-truth latent.
    i, loss, error, loss_fn, pred : accepted for callback compatibility and
        otherwise unused (`i` is only referenced by the disabled export code).
    """
    # detach() so this also works when called outside a `torch.no_grad()` block
    target = batch['latent'][0, :, :].detach().cpu().numpy()
    est = model.latent[0, :, :].detach().cpu().numpy()
    err_heatmap = target - est

    # `multi_heatmap` is imported directly from `src.heatmap`; the module name
    # `heatmap` itself is not bound in this notebook.
    multi_heatmap(
        [target, est, err_heatmap],
        ["Target", "Estimate", "Error"], *args, **kwargs)
    # Optional export of the figure and raw arrays (disabled):
    # plt.savefig(f"paper_ml4ps/inverse_reg_{i}.png",
    #     dpi=300, bbox_inches='tight', pad_inches=0)
    # np.savez_compressed(f"paper_ml4ps/inverse_reg_{i}.npz",
    #     target=target,
    #     est=est,
    #     error=err_heatmap
    # )
    plt.show()
    # Close the figure so repeated callbacks do not accumulate open figures.
    plt.close("all")

In [None]:
class NaiveLatent(nn.Module):
    """
    Wrap a frozen process predictor and expose a latent field as the only
    trainable parameter.

    The latent is a float32 parameter of shape ``(n_batch, *dims)``,
    initialised to zeros. Calling the module substitutes that parameter for
    the 'latent' entry of the incoming batch and runs the predictor on it.
    """

    def __init__(self,
            process_predictor: "nn.Module",
            dims: Tuple[int, int]=(256,256),
            n_batch: int=1):
        super().__init__()
        self.dims = dims
        self.process_predictor = process_predictor
        # The predictor's weights are fixed; only the latent is optimised.
        self.process_predictor.requires_grad_(False)
        initial = torch.zeros((n_batch, *dims), dtype=torch.float32)
        self.latent = nn.Parameter(initial)

    def weights_init(self):
        """Overwrite the latent in place with draws from N(0, 0.01**2)."""
        self.latent.data.normal_(mean=0.0, std=0.01)

    def forward(self, batch):
        """
        Run the predictor on a shallow copy of `batch` whose 'latent' entry
        is this module's parameter; the caller's dict is left untouched.
        """
        substituted = {**batch, 'latent': self.latent}
        return self.process_predictor(substituted)



In [None]:
def fit(
        batch,
        model,
        loss_fn,
        optimizer,
        n_iter:int=20,
        check_int:int=1,
        clip_val: Optional[float] = None,
        callback = lambda *x: None,
        # pen_0: float = 0.0,
        pen_1: float = 0.0,
        pen_f: float = 0.0,
        stop_on_truth: bool = False,
        diminishing_returns=1.1,):
    """
    Fit `model.latent` by gradient descent so the model's forecast matches
    `batch['y']`, tracking how close the latent gets to `batch['latent']`.

    Parameters
    ----------
    batch : mapping with the target under 'y', the ground-truth latent under
        'latent', and whatever else `model` reads from it.
    model : module providing `weights_init()`, a `latent` tensor, and a call
        returning a mapping with a 'forecast' entry.
    loss_fn : callable ``(prediction, target) -> scalar tensor``.
    optimizer : optimizer whose parameter groups contain the fitted params.
    n_iter : number of optimisation steps; must be at least 1.
    check_int : diagnostics (and early-stop checks) run every `check_int`
        iterations and on the last one.
    clip_val : if given, gradients are clipped element-wise to
        ``[-clip_val, clip_val]`` before each step.
    callback : called as ``callback(model, i, loss_v, error, loss_fn, batch,
        pred)`` at every diagnostic check.
    pen_1, pen_f : accepted for interface compatibility; currently unused.
    stop_on_truth : also stop when the error against the true latent grows
        by more than the `diminishing_returns` factor between checks.
    diminishing_returns : growth factor that triggers early stopping.

    Returns
    -------
    ``(loss_v, error, relerr, scale, big_losses, big_scale)`` where `scale`
    is the loss of an all-zero latent against the truth, `relerr` is
    ``sqrt(error / scale)``, `big_losses` is a list of per-sample diagnostic
    dicts (one per check), and `big_scale` is the per-sample scale as a NumPy
    array.

    Raises
    ------
    ValueError
        If `n_iter` is smaller than 1.
    """
    if n_iter < 1:
        raise ValueError("n_iter must be at least 1")
    model.train()
    model.weights_init()
    # Start from +inf so the first check can never trigger early stopping
    # (`10^5` in Python is a bitwise XOR, i.e. 15, not a large number).
    prev_loss_v = float('inf')
    prev_error = float('inf')
    big_losses = []
    # Unreduced MSE gives per-sample diagnostics alongside the scalar loss.
    big_loss_fn = nn.MSELoss(reduction='none')
    big_scale = big_loss_fn(torch.zeros_like(batch['latent']), batch['latent']).mean((1,2))
    scale = loss_fn(torch.zeros_like(batch['latent']), batch['latent']).item()
    for i in range(n_iter):
        # Compute prediction error
        pred = model(batch)
        loss = loss_fn(pred['forecast'], batch['y'])

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        if clip_val is not None:
            for group in optimizer.param_groups:
                torch.nn.utils.clip_grad_value_(group["params"], clip_val)

        optimizer.step()

        if i % check_int == 0 or i == n_iter - 1:
            with torch.no_grad():
                # recalc without penalties
                loss_v = loss_fn(pred['forecast'], batch['y']).item()
                # The first 16 iterations are exempt from the loss-based stop.
                if loss_v > diminishing_returns * prev_loss_v and i > 15:
                    print("Early stopping at optimum")
                    break
                prev_loss_v = loss_v
                error = loss_fn(model.latent, batch['latent']).item()
                if error > diminishing_returns * prev_error and stop_on_truth:
                    print("Early stopping at minimum prediction error")
                    break
                prev_error = error
                relerr = sqrt(error/scale)

                # Per-sample diagnostics
                big_loss_v = big_loss_fn(pred['forecast'], batch['y']).mean((1,2,3))
                big_error = big_loss_fn(model.latent, batch['latent']).mean((1,2))
                # NOTE(review): this normalises by the global `scale`, not the
                # per-sample `big_scale` — confirm which one is intended.
                big_relerr = torch.sqrt(big_error/scale)
                big_losses.append(dict(
                    big_loss=big_loss_v.detach().cpu().numpy(),
                    big_error = big_error.detach().cpu().numpy(),
                    relerr=big_relerr.detach().cpu().numpy()
                ))

                print(
                    f"loss: {loss_v:.3e}, error: {error:.3e}, relerror: {relerr:.3e} [{i:>5d}/{n_iter:>5d}]")
                callback(model, i, loss_v, error, loss_fn, batch, pred)

    # Final summary, based on the loss of the last executed iteration and the
    # latent as it stands after the last optimizer step.
    loss_v = loss.item()
    error = loss_fn(model.latent, batch['latent']).item()
    scale = loss_fn(torch.zeros_like(batch['latent']), batch['latent']).item()
    relerr = sqrt(error/scale)
    print(
        f"loss: {loss_v:.3e}, error: {error:.3e}, relerror: {relerr:.3e} scale: {scale:.3e}[{i:>5d}/{n_iter:>5d}]")

    return loss_v, error, relerr, scale, big_losses, big_scale.detach().cpu().numpy()



In [None]:
class Fourier2dMapping(nn.Module):
    """
    Pair of maps between a real 2-D field and its half-space complex spectrum.

    Does not work yet: this was an attempt at something fancy with
    parameterizations.
    TODO.
    """

    def __init__(self, modes: int=20, dims: Tuple[int, int]=(256,256)):
        super().__init__()
        self.dims = dims
        # Stored but not used by the transforms below.
        self.modes = modes  # maybe just normalize the weights?

    def forward(self, X):
        """
        Complex half-space input -> real full-space output of size `dims`
        (orthonormal inverse real FFT over the last two axes).
        """
        print("X", X.shape, X.dtype)
        real_field = fft.irfft2(X, s=self.dims, norm="ortho")
        return real_field

    def right_inverse(self, Xp):
        """
        Real full-space input -> complex half-space output
        (orthonormal real FFT over the last two axes, sized by `dims`).
        """
        half_spectrum = fft.rfft2(Xp, s=self.dims, norm="ortho")
        return half_spectrum



## Probabilistic version

http://pyro.ai/examples/mle_map.html
http://pyro.ai/examples/modules.html