# Inference in a PDE model



Look, we can do approximate probabilistic inference with a complicated neural network!



## Set up

In [None]:
import numpy as np
from math import sqrt, ceil

import matplotlib.pyplot as plt

import h5py

import torch
import torch.fft as fft
import torch.nn as nn
from torch.optim import AdamW
from torch.nn.modules.loss import MSELoss

from src.nn_modules.fourier_2d_generic import SimpleBlock2dGeneric
from src.heatmap import multi_heatmap
from src.utils import resolve_path

# Compute device used by the helpers below; swap the two lines to run on a GPU.
# device = torch.device('cuda')
device = torch.device('cpu')


This dataset is packed as $(b, x, y, t)$, i.e. batch first, spatial coordinates in the middle, time last.


In [None]:
# Open the dataset read-only: nothing in this notebook writes to it, and an
# explicit mode avoids older h5py defaults that could create or modify the file.
# The handle is deliberately left open because `data` is read lazily later on.
data_file = h5py.File(resolve_path('./data/grf_forcing_pico.h5'), 'r')
# The 'valid' group is the source of every observation used below.
data = data_file['valid']

def get_obs(data, t=0, n_steps=2, y=False):
    """
    Cut one observation window out of a dataset.

    Parameters
    ----------
    data : mapping with entries 'u' and 'f' that support array indexing
    t : index of the first step of the window along the last axis of 'u'
    n_steps : number of consecutive steps in the window
    y : when true, also return the single step that follows the window

    Returns
    -------
    dict with 'x' (``u[..., t:t+n_steps]``), 'latent' (all of 'f') and,
    if requested, 'y' (``u[..., t+n_steps]``).
    """
    window_end = t + n_steps
    field = data['u']
    obs = {
        'x': field[..., t:window_end],
        'latent': data['f'][...],
    }
    if y:
        # the target is the step immediately after the input window
        obs['y'] = field[..., window_end]
    return obs


In [None]:
def dict_as_tensor(d, device=device):
    """
    Convert every numpy array in a dict to a torch tensor on `device`.

    Keys are kept as they are; each value goes through `torch.from_numpy`
    and is then moved to `device`. A new dict is returned.
    """
    tensors = {}
    for key, array in d.items():
        tensors[key] = torch.from_numpy(array).to(device)
    return tensors

In [None]:

def multi_img_time(x, batch=0, interval=1, n_cols=None, fsize=6):
    """
    Plot multiple timesteps of an array (time last)

    Parameters
    ----------
    x : array indexed as ``x[batch, ..., step]``
    batch : index along the first axis of the example to draw
    interval : stride between plotted steps
    n_cols : number of grid columns; defaults to one column per plotted step
    fsize : figure height; the width is scaled by ``n_cols / n_rows``

    Returns
    -------
    The matplotlib figure holding the grid of images.
    """
    steps = range(0, x.shape[-1], interval)
    if n_cols is None:
        n_cols = len(steps)
    n_rows = ceil(len(steps) / n_cols)
    # squeeze=False always gives a 2-D array of Axes, so the flatten below
    # also works when the grid is a single cell.
    fig, axes = plt.subplots(
        n_rows, n_cols,
        figsize=(fsize * n_cols / n_rows, fsize),
        squeeze=False)
    axes = axes.flatten()
    # hide every axis first so unused trailing cells stay blank
    for ax in axes:
        ax.set_axis_off()
    for i, ax in zip(steps, axes):
        ax.imshow(x[batch, ..., i])
    plt.tight_layout()
    return fig

def multi_img_batch(x, interval=1, n_cols=None, fsize=6):
    """
    Plot multiple batches of an array

    Parameters
    ----------
    x : array whose first axis enumerates the batch entries; each ``x[i]``
        is drawn as one image
    interval : stride between plotted batch entries
    n_cols : number of grid columns; defaults to one column per plotted entry
    fsize : figure height; the width is scaled by ``n_cols / n_rows``

    Returns
    -------
    The matplotlib figure holding the grid of images.
    """
    steps = range(0, x.shape[0], interval)
    if n_cols is None:
        n_cols = len(steps)
    n_rows = ceil(len(steps) / n_cols)
    # squeeze=False always gives a 2-D array of Axes, so the flatten below
    # also works when the grid is a single cell.
    fig, axes = plt.subplots(
        n_rows, n_cols,
        figsize=(fsize * n_cols / n_rows, fsize),
        squeeze=False)
    axes = axes.flatten()
    # hide every axis first so unused trailing cells stay blank
    for ax in axes:
        ax.set_axis_off()
    for i, ax in zip(steps, axes):
        # `steps` runs over the first (batch) axis, so index that axis
        ax.imshow(x[i])
    plt.tight_layout()
    return fig

In [None]:
# Take the first 8 steps (t=0, n_steps=8) of the observed field.
arr = get_obs(data, 0, 8)['x']

# Draw them for batch entry 0 on a grid with two columns.
multi_img_time(arr, n_cols=2);


## We have a process predictor model

In [None]:
# Load the saved forward-model weights, remapping their storage to `device`.
pp_state_dict = torch.load(
    './models/fno_forward.ckpt',
    map_location=device
)
# Rebuild the network with the hyperparameters the checkpoint is expected to
# match (presumably those used at training time -- verify against the training run).
process_predictor = SimpleBlock2dGeneric(
    modes1=16,
    width=24,
    n_layers=4,
    n_history=2,
    param=False,
    forcing=False,
    latent=True,
)
# NOTE(review): the module itself is never moved with `.to(device)` here, so the
# cells below only line up with `device` when it is the CPU -- confirm before
# switching to CUDA.
process_predictor.load_state_dict(
    pp_state_dict
)

In [None]:
def predict_forward(x, latent, n_horizon=1, n_steps=2):
    """
    Roll the process predictor forward autoregressively.

    Parameters
    ----------
    x : observed history, time on the last axis; converted to a tensor on `device`
    latent : latent field passed unchanged to every predictor call
    n_horizon : number of predictor calls to chain
    n_steps : number of trailing steps of the running history given to each call

    Returns
    -------
    The last `n_horizon` steps of the extended history. With `n_horizon=0`
    this is an empty slice along the time axis rather than the whole history.
    """
    x = torch.as_tensor(x).to(device)
    latent = torch.as_tensor(latent).to(device)
    for _ in range(n_horizon):
        # each call sees only the most recent n_steps, including earlier forecasts
        window = x[..., -n_steps:]
        pred = process_predictor({'x': window, 'latent': latent})['forecast']
        x = torch.cat((x, pred), dim=-1)
    # An explicit start index is used because `-0:` would select everything.
    start = x.shape[-1] - n_horizon
    return x[..., start:]

# Ten observed steps starting at t=0; predict_forward only feeds the last
# n_steps=2 of them to the predictor.
obs = get_obs(data, 0, 10)

# Inference only: no gradients are needed for the rollout.
with torch.no_grad():
    pred = predict_forward(obs['x'], obs['latent'], n_horizon=100, n_steps=2)

# Show every 5th predicted step on a grid with five columns.
multi_img_time(pred.cpu().numpy(), n_cols=5, interval=5);


In [None]:
# Redraw the same sheet of predictions, then write the current figure to disk.
multi_img_time(pred.cpu().numpy(), n_cols=5, interval=5);
plt.savefig('./fno_forward_predict_sheet.jpg')

## Inversion by gradient descent (GD)

In [None]:
class RasterLatent(nn.Module):
    """
    Point estimate of a latent raster, pushed through a frozen process predictor.

    The only trainable tensor is `latent`, of shape ``(n_batch, *dims)``.
    All parameters of the wrapped predictor have `requires_grad` switched off.
    """

    def __init__(self,
            process_predictor: "nn.Module",
            dims = (256,256),
            n_batch: int=1):
        super().__init__()
        self.dims = dims
        self.process_predictor = process_predictor
        # The predictor is fixed; gradients should only reach the latent.
        for frozen in self.process_predictor.parameters():
            frozen.requires_grad = False
        initial = torch.zeros((n_batch, *dims), dtype=torch.float32)
        self.latent = nn.Parameter(initial)

    def weights_init(self):
        """Re-draw the latent in place from a normal with mean 0 and std 0.01."""
        self.latent.data.normal_(mean=0.0, std=0.01)

    def forward(self, batch):
        """
        Run the predictor on `batch` with its 'latent' entry replaced by
        this module's own latent. The caller's dict is left untouched.
        """
        # shallow copy with the latent overridden
        overridden = {**batch, 'latent': self.latent}
        return self.process_predictor(overridden)



In [None]:
def fit(
        batch,
        model,
        optimizer,
        n_iter: int=20,
        check_int: int=1,
        clip_val = None,
        callback = lambda *x: None):
    """
    Fit `model.latent` by gradient descent so the model's forecast matches `batch['y']`.

    Parameters
    ----------
    batch : dict of tensors; 'y' is the forecast target and 'latent' is the
        true latent, used only for error reporting. The whole dict is passed
        to `model`.
    model : module exposing `weights_init()`, a `latent` parameter, and a
        forward pass that returns a dict with a 'forecast' entry
    optimizer : optimizer over the model's parameters
    n_iter : number of optimisation steps
    check_int : report every `check_int` steps (and always on the last step)
    clip_val : if not None, gradients are clipped element-wise to this value
    callback : currently unused; kept so existing callers keep working

    Returns
    -------
    ``(loss_v, error, relerr, scale)`` from the most recent report:
    the forecast MSE, the latent MSE, the mean over the batch of the
    per-example relative RMS latent error, and the mean square of the true
    latent. The first three are NaN when `n_iter` is 0.
    """
    model.train()
    model.weights_init()
    loss_fn = nn.MSELoss()
    big_loss_fn = nn.MSELoss(reduction='none')
    # mean square of the true latent, used to normalise the latent error
    scale = loss_fn(torch.zeros_like(batch['latent']), batch['latent']).item()
    # defined up front so the return below is valid even with n_iter == 0
    loss_v = error = relerr = float('nan')
    for i in range(n_iter):
        # Compute prediction error
        pred = model(batch)
        loss = loss_fn(pred['forecast'], batch['y'])

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        if clip_val is not None:
            for group in optimizer.param_groups:
                torch.nn.utils.clip_grad_value_(group["params"], clip_val)

        optimizer.step()

        if i % check_int == 0 or i == n_iter - 1:
            with torch.no_grad():
                # loss is from before this step; the latent errors are from after it
                loss_v = loss.item()
                # per-example latent error
                big_error = big_loss_fn(model.latent, batch['latent']).mean(dim=(1,2))
                # already a root: relative RMS error per example
                big_relerr = torch.sqrt(big_error/scale)
                error = big_error.mean().item()
                # no second square root here
                relerr = big_relerr.mean().item()
                print(
                    f"loss: {loss_v:.3e}, error: {error:.3e}, relerror: {relerr:.3e} [{i:>5d}/{n_iter:>5d}]")

                # visual check on the first batch entry
                target = batch['latent'][0, :, :].cpu().numpy()
                est = model.latent[0, :, :].cpu().numpy()
                err_heatmap = target - est
                multi_heatmap(
                    [target, est, err_heatmap],
                    ["Target", "Estimate", "Error"])
                plt.show()
                plt.close("all")

    return loss_v, error, relerr, scale


In [None]:
# Fresh latent with the spatial shape taken from axes 1 and 2 of the observations.
model = RasterLatent(
    process_predictor,
    dims=obs['x'].shape[1:3],
    n_batch=1)
# First run: no weight decay.
optimizer = AdamW(
    model.parameters(),
    lr=0.0025,
    weight_decay=0.0)

# NOTE(review): this `loss_fn` is not passed to `fit`, which builds its own.
loss_fn = nn.MSELoss()

# Fit on the window t=0..1 with the following step as the target; report every 10 steps.
fit(
    dict_as_tensor(get_obs(data,t=0,n_steps=2,y=True)),
    model,
    optimizer,
    n_iter=50,
    check_int=10,
    clip_val=None,
)

In [None]:
# Fresh latent with the spatial shape taken from axes 1 and 2 of the observations.
model = RasterLatent(
    process_predictor,
    dims=obs['x'].shape[1:3],
    n_batch=1)
# Second run: ten times smaller learning rate and a weight decay of 10.
optimizer = AdamW(
    model.parameters(),
    lr=0.00025,
    weight_decay=10)

# NOTE(review): this `loss_fn` is not passed to `fit`, which builds its own.
loss_fn = nn.MSELoss()

# Same data and schedule as the previous cell.
fit(
    dict_as_tensor(get_obs(data,t=0,n_steps=2,y=True)),
    model,
    optimizer,
    n_iter=50,
    check_int=10,
    clip_val=None,
)

## Probabilistic version

* [Modules in Pyro — Pyro Tutorials](https://pyro.ai/examples/modules.html)
* [Neural Networks — Pyro documentation](https://docs.pyro.ai/en/stable/nn.html#pyro.nn.module.PyroSample)

In [None]:
import pyro
from pyro.nn import PyroModule, PyroParam, PyroSample
from pyro.nn.module import to_pyro_module_
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, Predictive, MCMC, NUTS
from pyro import poutine


In [None]:
class ProbRasterLatent(PyroModule):
    """
    Probabilistic counterpart of the raster-latent model.

    The latent raster is a Pyro sample site with an i.i.d. zero-mean normal
    prior of standard deviation `prior_scale` over `dims`. The forecast of the
    frozen process predictor is the mean of a normal observation model with
    standard deviation `obs_scale`.

    Parameters
    ----------
    process_predictor : module mapping ``{'x': ..., 'latent': ...}`` to a dict
        with a 'forecast' entry; it is put in eval mode and its parameters are frozen
    dims : shape of the latent raster
    prior_scale : standard deviation of the prior on each latent cell
    obs_scale : standard deviation of the observation noise
    """

    def __init__(
            self,
            process_predictor: "nn.Module",
            dims = (256,256),
            prior_scale = 0.01,
            obs_scale = 0.01,):
        super().__init__()
        self.dims = dims
        self.prior_scale = prior_scale
        self.obs_scale = obs_scale
        self.process_predictor = process_predictor
        process_predictor.train(False)
        ## Do not fit the process predictor weights
        for param in self.process_predictor.parameters():
            param.requires_grad = False
        # The prior width comes from the constructor argument, not a constant.
        self.latent = PyroSample(
            dist.Normal(0, prior_scale).expand(dims).to_event(2))

    def forward(self, X, y=None):
        """
        Sample (or condition on) the observation given the history `X`.

        The latent sample gets a leading axis of size one before it is passed
        to the predictor. When `y` is given, the "obs" site is conditioned on it.
        """
        # overwrite process predictor batch with my own latent
        mean = self.process_predictor({
            'x': X,
            'latent': self.latent.unsqueeze(0),
        })['forecast']
        # NOTE(review): to_event(2) makes only the last two axes of the
        # forecast event dims -- confirm this matches the forecast layout.
        return pyro.sample(
            "obs", dist.Normal(mean, self.obs_scale).to_event(2),
            obs=y)

# Build the probabilistic model; `obs` here is still the one from the earlier
# rollout cell, used only for its spatial shape.
model = ProbRasterLatent(
    process_predictor,
    dims=obs['x'].shape[1:3],
    prior_scale=0.01,
    obs_scale=0.01,
)

nuts_kernel = NUTS(model, full_mass=False, max_tree_depth=5, jit_compile=True) # high performance params

mcmc = MCMC(nuts_kernel, num_samples=100, warmup_steps=100)
# Re-read the observations: two input steps (t=0..1) plus the following step as the target.
obs = get_obs(data, 0, 2, y=True)
mcmc.run(torch.as_tensor(obs['x']), torch.as_tensor(obs['y']))
# Convert every sampled site to a numpy array on the CPU.
mc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}


In [None]:
# Unpack the dict so each sample site is stored as its own named array.
# Passed positionally, the whole dict would be saved as one pickled object
# array under the key 'arr_0'.
np.savez_compressed("data/samples1.npz", **mc_samples)

In [None]:
# Visualise the sampled 'latent' site as a row of images, using a stride of 20.
multi_img_batch(mc_samples['latent'], interval=20)

## offcuts

In [None]:
class Fourier2dMapping(nn.Module):
    """
    Does not work because I tried to do something fancy with parameterizations.
    TODO.

    Pairs an inverse real 2-D FFT (`forward`) with the matching real 2-D FFT
    (`right_inverse`), both orthonormal and both using `dims` as the spatial size.
    `modes`, `prior_scale` and `obs_scale` are stored but not used by either method.
    """

    def __init__(self, modes: int=20, dims=(256,256), prior_scale=0.01, obs_scale=0.01):
        super().__init__()
        # maybe just normalize the weights?
        self.modes, self.dims = modes, dims
        self.prior_scale, self.obs_scale = prior_scale, obs_scale

    def forward(self, X):
        """
        Map a complex half-space spectrum to a real field on the full grid.
        """
        print("X", X.shape, X.dtype)
        real_field = fft.irfft2(X, s=self.dims, norm="ortho")
        return real_field

    def right_inverse(self, Xp):
        """
        Map a real field on the full grid to its complex half-space spectrum.
        """
        half_spectrum = fft.rfft2(Xp, s=self.dims, norm="ortho")
        return half_spectrum

