# Inverting Hydrology with Neural Nets



In [None]:
# Load IPython's autoreload extension so edited project modules are picked up
# without restarting the kernel.
%load_ext autoreload
# Mode 2: reload modules automatically before each cell is executed.
%autoreload 2


## Set up

In [None]:
from src import minimise_predictive_loss


In [None]:
import os
from pathlib import Path
from dotenv import load_dotenv, find_dotenv
import numpy as np
import random
import functools
from math import sqrt
from PIL import Image

from typing import Any, Callable, Tuple

import torch

import matplotlib.pyplot as plt
from matplotlib import rc
from matplotlib_inline.backend_inline import set_matplotlib_formats
# Ask the inline backend for PNG output.
set_matplotlib_formats('png')
# Optional global figure size / resolution overrides, currently disabled.
# plt.rcParams.update({'figure.figsize': [12, 12]})
# plt.rcParams.update({'figure.dpi': 200})
# Larger default font for every figure created after this point.
plt.rcParams.update({'font.size': 20})
%matplotlib inline
# Load variables from the nearest .env file into the process environment.
# Presumably this supplies the ${FNO_DATA_ROOT} / ${SM_MODEL_DIR} placeholders
# used in paths further down -- confirm against the project's .env.
# The trailing semicolon suppresses the cell's echoed return value.
load_dotenv(find_dotenv());

In [None]:

from typing import Any, Dict, Tuple, Optional
from math import ceil, sqrt

import torch
import torch.nn.utils.parametrize as parametrize
import torch.fft as fft
import torch.nn as nn
from einops import rearrange, repeat
from torch.optim import AdamW

from fourierflow.nn_modules.loss import LpLoss
from fourierflow.nn_modules.fourier_2d_generic import SimpleBlock2dGeneric
from fourierflow.viz.heatmap import navier_stokes_heatmap, multi_heatmap
from fourierflow.datastores.navier_stokes_h5 import NavierStokesH5InstDatastore
from fourierflow.utils import resolve_path
from fourierflow.optimizers import AdamWC
from torch.linalg import vector_norm, matrix_norm

from scipy.optimize import bisect


In [None]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook still runs on machines without CUDA instead of failing at the first
# `.to(device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


## Inversion by gradient descent (GD)

### unregularized

In [None]:
# Import the project modules through %aimport so that autoreload tracks them,
# then bind short aliases used by the remaining cells.
%aimport fourierflow._infer
infer = fourierflow._infer
%aimport fourierflow.viz.heatmap
heatmap = fourierflow.viz.heatmap

# def plot_heatmap(model, i, loss, error, loss_fn, batch, pred, *args, **kwargs):
#     err_heatmap = target - est
#     fig = heatmap.multi_heatmap([target, est, err_heatmap], ["target", "est", "error"], *args, **kwargs)
#     plt.show();
#     # plt.savefig(f"{i}.png")
#     plt.close("all");

def plot_heatmap(model, i, loss, error, loss_fn, batch, pred, *args, **kwargs):
    """Plot, save and show target / estimate / error maps for the first sample.

    Passed as ``callback`` to ``infer.main``. ``loss``, ``error``, ``loss_fn``
    and ``pred`` are accepted but not used here.

    Args:
        model: Object whose ``latent`` attribute holds the current estimate.
        i: Value interpolated into the output file names (presumably the
            iteration index -- confirm against the caller).
        batch: Mapping whose ``'latent'`` entry holds the target field.
        *args: Forwarded unchanged to ``heatmap.multi_heatmap``.
        **kwargs: Forwarded unchanged to ``heatmap.multi_heatmap``.
    """
    # Detach before converting: ``Tensor.numpy()`` raises for tensors that
    # require grad, and the other cells that read ``model.latent`` detach too.
    target = batch['latent'][0, :, :].detach().cpu().numpy()
    est = model.latent[0, :, :].detach().cpu().numpy()
    err_heatmap = target - est

    heatmap.multi_heatmap(
        [target, est, err_heatmap],
        ["Target", "Estimate", "Error"], *args, **kwargs)

    # The figure and the raw arrays share one file stem.
    stem = f"paper_ml4ps/inverse_{i}"
    plt.savefig(f"{stem}.png", dpi=300, bbox_inches='tight', pad_inches=0)
    np.savez_compressed(
        f"{stem}.npz",
        target=target,
        est=est,
        error=err_heatmap,
    )

    plt.show()
    # Close every figure so repeated callbacks do not accumulate open figures.
    plt.close("all")


In [None]:
# Seed torch's global RNG before the run.
torch.manual_seed(57)

# Inversion on the grf_forcing_mini data set; no penalty arguments are passed.
# `plot_heatmap` is handed over as the callback (check_int=5 presumably sets
# how often it fires -- confirm in infer.main).
infer.main(
    callback=plot_heatmap,
    lr=0.005,
    weight_decay=0.0,
    n_iter=50,
    check_int=5,
    ds_args=dict(
        data_path='${FNO_DATA_ROOT}/navier-stokes/grf_forcing_mini.h5',
        ssr=1,
        n_history=2,
        n_horizon=1,
        batch_size=20,
        latent_key='f',
        forcing_key='',
        param_key='',
        n_workers=0,
    ),
)

In [None]:
# Seed torch's global RNG before the run.
torch.manual_seed(57)

# Same settings as the grf_forcing_mini run, pointed at grf_forcing_mini_1.h5.
infer.main(
    callback=plot_heatmap,
    lr=0.005,
    weight_decay=0.0,
    n_iter=50,
    check_int=5,
    ds_args=dict(
        data_path='${FNO_DATA_ROOT}/navier-stokes/grf_forcing_mini_1.h5',
        ssr=1,
        n_history=2,
        n_horizon=1,
        batch_size=20,
        latent_key='f',
        forcing_key='',
        param_key='',
        n_workers=0,
    ),
)

In [None]:

# model = infer.main(
#     callback=plot_heatmap, lr=0.1, weight_decay=0.0, n_iter=10, mapping='fourier')

### diff penalty

In [None]:

def plot_heatmap(model, i, loss, error, loss_fn, batch, pred, *args, **kwargs):
    """Plot, save and show target / estimate / error maps for the first sample.

    Variant used for the penalised runs: output goes to
    ``paper_ml4ps/inverse_reg_{i}.png`` / ``.npz``. ``loss``, ``error``,
    ``loss_fn`` and ``pred`` are accepted but not used here.

    Args:
        model: Object whose ``latent`` attribute holds the current estimate.
        i: Value interpolated into the output file names (presumably the
            iteration index -- confirm against the caller).
        batch: Mapping whose ``'latent'`` entry holds the target field.
        *args: Forwarded unchanged to ``heatmap.multi_heatmap``.
        **kwargs: Forwarded unchanged to ``heatmap.multi_heatmap``.
    """
    # Detach before converting: ``Tensor.numpy()`` raises for tensors that
    # require grad.
    target = batch['latent'][0, :, :].detach().cpu().numpy()
    est = model.latent[0, :, :].detach().cpu().numpy()
    err_heatmap = target - est

    heatmap.multi_heatmap(
        [target, est, err_heatmap],
        ["Target", "Estimate", "Error"], *args, **kwargs)

    # The figure and the raw arrays share one file stem.
    stem = f"paper_ml4ps/inverse_reg_{i}"
    plt.savefig(f"{stem}.png", dpi=300, bbox_inches='tight', pad_inches=0)
    np.savez_compressed(
        f"{stem}.npz",
        target=target,
        est=est,
        error=err_heatmap,
    )
    plt.show()
    # Close every figure so repeated callbacks do not accumulate open figures.
    plt.close("all")

# Seed torch's global RNG before the run.
torch.manual_seed(59)

# Penalised inversion (pen_1=30) on grf_forcing_mini.h5; the three returned
# values are kept for later cells.
model, fit, etc = infer.main(
    ds_args=dict(
        data_path='${FNO_DATA_ROOT}/navier-stokes/grf_forcing_mini.h5',
        ssr=1,
        n_history=2,
        n_horizon=1,
        batch_size=20,
        latent_key='f',
        forcing_key='',
        param_key='',
        n_workers=0,
    ),
    callback=plot_heatmap,
    lr=0.01,
    weight_decay=0.00,
    n_iter=200,
    pen_1=30,
    check_int=5,
)

In [None]:

def plot_heatmap(model, i, loss, error, loss_fn, batch, pred, *args, **kwargs):
    """Plot, save and show target / estimate / error maps for the first sample.

    Output goes to ``paper_ml4ps/inverse_reg_{i}.png`` / ``.npz``. ``loss``,
    ``error``, ``loss_fn`` and ``pred`` are accepted but not used here.

    NOTE(review): these are the same output paths as the callback defined in
    the preceding penalised cell, so running both cells overwrites the earlier
    files -- confirm that this is intended.

    Args:
        model: Object whose ``latent`` attribute holds the current estimate.
        i: Value interpolated into the output file names (presumably the
            iteration index -- confirm against the caller).
        batch: Mapping whose ``'latent'`` entry holds the target field.
        *args: Forwarded unchanged to ``heatmap.multi_heatmap``.
        **kwargs: Forwarded unchanged to ``heatmap.multi_heatmap``.
    """
    # Detach before converting: ``Tensor.numpy()`` raises for tensors that
    # require grad.
    target = batch['latent'][0, :, :].detach().cpu().numpy()
    est = model.latent[0, :, :].detach().cpu().numpy()
    err_heatmap = target - est

    heatmap.multi_heatmap(
        [target, est, err_heatmap],
        ["Target", "Estimate", "Error"], *args, **kwargs)

    # The figure and the raw arrays share one file stem.
    stem = f"paper_ml4ps/inverse_reg_{i}"
    plt.savefig(f"{stem}.png", dpi=300, bbox_inches='tight', pad_inches=0)
    np.savez_compressed(
        f"{stem}.npz",
        target=target,
        est=est,
        error=err_heatmap,
    )
    plt.show()
    # Close every figure so repeated callbacks do not accumulate open figures.
    plt.close("all")

# Seed torch's global RNG before the run.
torch.manual_seed(59)

# Penalised inversion (pen_1=30) on grf_forcing_mini_1.h5; the three returned
# values replace any earlier `model`, `fit` and `etc`.
model, fit, etc = infer.main(
    ds_args=dict(
        data_path='${FNO_DATA_ROOT}/navier-stokes/grf_forcing_mini_1.h5',
        ssr=1,
        n_history=2,
        n_horizon=1,
        batch_size=20,
        latent_key='f',
        forcing_key='',
        param_key='',
        n_workers=0,
    ),
    callback=plot_heatmap,
    lr=0.01,
    weight_decay=0.00,
    n_iter=200,
    pen_1=30,
    check_int=5,
)

## choosing regularization

In [None]:
from fourierflow.datastores.navier_stokes_h5 import NavierStokesH5InstDatastore
# Seed torch's global RNG for the regularisation sweep.
torch.manual_seed(60)
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# sweep can still run on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Passed to infer.NaiveLatent as `n_batch` / `dims` below.
n_batch = 20
dims = (256, 256)

def plot_heatmap(model, i, loss, error, loss_fn, batch, pred, *args, **kwargs):
    """Show target / estimate / error maps for the first sample of the batch.

    Display-only variant: nothing is written to disk. ``i``, ``loss``,
    ``error``, ``loss_fn`` and ``pred`` are accepted but not used here.

    Args:
        model: Object whose ``latent`` attribute holds the current estimate.
        batch: Mapping whose ``'latent'`` entry holds the target field.
        *args: Forwarded unchanged to ``heatmap.multi_heatmap``.
        **kwargs: Forwarded unchanged to ``heatmap.multi_heatmap``.
    """
    # Detach before converting: ``Tensor.numpy()`` raises for tensors that
    # require grad.
    target = batch['latent'][0, :, :].detach().cpu().numpy()
    est = model.latent[0, :, :].detach().cpu().numpy()
    err_heatmap = target - est

    heatmap.multi_heatmap(
        [target, est, err_heatmap],
        ["Target", "Estimate", "Error"], *args, **kwargs)
    # Saving is disabled for the sweep; to re-enable, restore:
    # plt.savefig(f"paper_ml4ps/inverse_reg_{i}.png",
    #     dpi=300, bbox_inches='tight', pad_inches=0)
    # np.savez_compressed(f"paper_ml4ps/inverse_reg_{i}.npz",
    #     target=target,
    #     est=est,
    #     error=err_heatmap
    # )
    plt.show()
    # Close every figure so repeated callbacks do not accumulate open figures.
    plt.close("all")


# --- Data: take a single validation batch from grf_forcing_mini_1.h5 --------
datastore = NavierStokesH5InstDatastore(
    '${FNO_DATA_ROOT}/navier-stokes/grf_forcing_mini_1.h5',
    n_workers=0,
    ssr=1,
    n_history=2,
    n_horizon=1,
    batch_size=20,
    latent_key='f',
    forcing_key='',
    param_key='',
)
dataloader = datastore.val_dataloader()
batch = next(iter(dataloader))

# --- Forward model: rebuild the network and load its checkpointed weights ---
pp_state_dict = torch.load(
    resolve_path('${SM_MODEL_DIR}/history_matching/adequate_checkpoint/fwd-epoch=19-step=26399-valid_loss=0.00000.ckpt'),
    map_location=device,
)
process_predictor = SimpleBlock2dGeneric(
    modes1=16,
    width=24,
    n_layers=4,
    n_history=2,
    param=False,
    forcing=False,
    latent=True,
)
process_predictor.load_state_dict(pp_state_dict)

# --- Inverse model wrapping the forward model --------------------------------
model = infer.NaiveLatent(process_predictor, dims=dims, n_batch=n_batch)
model.to(device)

# Keep a NumPy copy of every batch entry for plotting, then move the tensors
# themselves onto the compute device.
npbatch = {}
for k, v in batch.items():
    npbatch[k] = v.cpu().numpy()
    batch[k] = v.to(device)

optimizer = AdamW(model.parameters(), lr=0.005, weight_decay=0.0)
loss_fn = nn.MSELoss().to(device)

# Penalty weights to try: 0, 1, 4, ..., 29**2.
lambdas = [root * root for root in range(30)]

# One fit per penalty weight; element [2] of each result is collected
# (presumably the relative error -- confirm against infer.fit).
# NOTE(review): the same `model` and `optimizer` objects are passed to every
# call; whether infer.fit resets them between penalties is not visible here.
relerrs = [
    infer.fit(
        batch,
        model,
        loss_fn,
        optimizer,
        n_iter=500,
        check_int=5,
        clip_val=None,
        pen_1=lambda_,
        stop_on_truth=True,
    )[2]
    for lambda_ in lambdas
]


In [None]:
# Dump the (lambda, relative error) pairs in a copy-pasteable form.
print(repr([*zip(lambdas, relerrs)]))

### longer model

In [None]:

def plot_heatmap(model, i, loss, error, loss_fn, batch, pred, *args, **kwargs):
    """Show target / estimate / error maps for the first sample of the batch.

    Display-only variant: nothing is written to disk. ``i``, ``loss``,
    ``error``, ``loss_fn`` and ``pred`` are accepted but not used here.

    Args:
        model: Object whose ``latent`` attribute holds the current estimate.
        batch: Mapping whose ``'latent'`` entry holds the target field.
        *args: Forwarded unchanged to ``heatmap.multi_heatmap``.
        **kwargs: Forwarded unchanged to ``heatmap.multi_heatmap``.
    """
    # Detach before converting: ``Tensor.numpy()`` raises for tensors that
    # require grad.
    target = batch['latent'][0, :, :].detach().cpu().numpy()
    est = model.latent[0, :, :].detach().cpu().numpy()
    err_heatmap = target - est

    heatmap.multi_heatmap(
        [target, est, err_heatmap],
        ["Target", "Estimate", "Error"], *args, **kwargs)
    # Saving is disabled here; to re-enable, restore:
    # plt.savefig(f"paper_ml4ps/inverse_reg_{i}.png", dpi=300, bbox_inches='tight', pad_inches=0)
    plt.show()
    # Close every figure so repeated callbacks do not accumulate open figures.
    plt.close("all")

# Seed torch's global RNG before the run.
torch.manual_seed(59)

# Penalised inversion through the forward model configured with n_history=10.
# The checkpoint is given as a glob-style pattern, and ds_args carries no
# 'data_path' here (presumably infer.main supplies a default -- confirm).
infer.main(
    fwd_state_dict_path='${SM_MODEL_DIR}/history_matching/adequate_long_wide/history_matching/*/checkpoints/fwd-*.ckpt',
    fwd_args=dict(
        modes1=16,
        width=24,
        n_layers=4,
        n_history=10,
        param=False,
        forcing=False,
        latent=True,
    ),
    ds_args=dict(
        ssr=1,
        n_history=10,
        n_horizon=1,
        batch_size=20,
        latent_key='f',
        forcing_key='',
        param_key='',
    ),
    callback=plot_heatmap,
    lr=0.05,
    weight_decay=0.00,
    n_iter=10,
    pen_1=30,
)

In [None]:
# Hard-coded (lambda, relative error) pairs, presumably pasted from the printed
# output of the sweep so the figure can be redrawn without re-running it.
lambdas, relerrs = zip(*[
    (0, 1.6007785766982043),
    (1, 0.9997765660372001),
    (4, 0.9007331947795258),
    (9, 0.8229570081477463),
    (16, 0.8009129041645531),
    (25, 0.7878376155161687),
    (36, 0.7800267568964597),
    (49, 0.7752918051130337),
    (64, 0.7707536070233962),
    (81, 0.7663303353976826),
    (100, 0.764045919801015),
    (121, 0.7603965646406099),
    (144, 0.759816943122361),
    (169, 0.7553044588047234),
    (196, 0.7556437640875308),
    (225, 0.7545405463842202),
    (256, 0.7592473939734609),
    (289, 0.769563432528108),
    (324, 0.7806189493001515),
    (361, 0.7940838967495987),
    (400, 0.8070344769974499),
    (441, 0.8209803543025779),
    (484, 0.8324510644630306),
    (529, 0.8433933790498325),
    (576, 0.8549393999655539),
    (625, 0.8649656038660587),
    (676, 0.8746389898996968),
    (729, 0.8822984327202239),
    (784, 0.8900433311252656),
    (841, 0.8969211483359959),
])

# plt.plot returns a list of line artists rather than a figure, so its result
# is not bound to a name.
plt.plot(lambdas, relerrs)
plt.xlabel(r'$\lambda$')
plt.ylabel("relative error")
plt.yscale("log")
# The lambda = 0 point (about 1.60) lies above this range and is cut off.
plt.ylim(0.75, 1)
plt.savefig("paper_ml4ps/inverse_reg_lambda.png", dpi=300, bbox_inches='tight', pad_inches=0)

In [None]:
# Pull the current latent estimate out of the model as a NumPy array.
arr = model.latent.detach().cpu().numpy()

# Show the first three entries next to each other.
multi_heatmap([arr[idx] for idx in range(3)], ["1", "2", "3"])

In [None]:
# Fit again on the existing batch / model / optimizer with a fixed pen_1 of 30
# for 1000 iterations; no callback and no pen_0 are passed.
infer.fit(
    batch, model, loss_fn, optimizer,
    n_iter=1000,
    check_int=5,
    clip_val=None,
    pen_1=30,
)


In [None]:
# Compare estimate and target for the first five samples of the batch.
arr = model.latent.detach().cpu().numpy()
for i in range(5):
    est = arr[i]
    target = npbatch['latent'][i]
    # Use the same sign convention as the plot_heatmap callbacks
    # (target - estimate) so error maps are comparable across the notebook.
    err = target - est
    multi_heatmap([est, target, err], ["est", "target", "Error"])

## Ensemble estimates

In [None]:
# Rebuild the datastore on grf_forcing_mini.h5 (n_workers is left at the
# datastore's own default here).
datastore = NavierStokesH5InstDatastore(
    '${FNO_DATA_ROOT}/navier-stokes/grf_forcing_mini.h5',
    ssr=1,
    n_history=2,
    n_horizon=1,
    batch_size=20,
    latent_key='f',
    forcing_key='',
    param_key='',
)

### longer model

In [None]:

def plot_heatmap(model, i, loss, error, loss_fn, batch, pred, *args, **kwargs):
    """Show target / estimate / error maps for the first sample of the batch.

    Display-only variant: nothing is written to disk. ``i``, ``loss``,
    ``error``, ``loss_fn`` and ``pred`` are accepted but not used here.

    Args:
        model: Object whose ``latent`` attribute holds the current estimate.
        batch: Mapping whose ``'latent'`` entry holds the target field.
        *args: Forwarded unchanged to ``heatmap.multi_heatmap``.
        **kwargs: Forwarded unchanged to ``heatmap.multi_heatmap``.
    """
    # Detach before converting: ``Tensor.numpy()`` raises for tensors that
    # require grad.
    target = batch['latent'][0, :, :].detach().cpu().numpy()
    est = model.latent[0, :, :].detach().cpu().numpy()
    err_heatmap = target - est

    heatmap.multi_heatmap(
        [target, est, err_heatmap],
        ["Target", "Estimate", "Error"], *args, **kwargs)
    # Saving is disabled here; to re-enable, restore:
    # plt.savefig(f"paper_ml4ps/inverse_reg_{i}.png", dpi=300, bbox_inches='tight', pad_inches=0)
    plt.show()
    # Close every figure so repeated callbacks do not accumulate open figures.
    plt.close("all")

# Seed torch's global RNG before the run.
torch.manual_seed(59)

# Penalised inversion through the forward model configured with n_history=10.
# The checkpoint is given as a glob-style pattern, and ds_args carries no
# 'data_path' here (presumably infer.main supplies a default -- confirm).
infer.main(
    fwd_state_dict_path='${SM_MODEL_DIR}/history_matching/adequate_long_wide/history_matching/*/checkpoints/fwd-*.ckpt',
    fwd_args=dict(
        modes1=16,
        width=24,
        n_layers=4,
        n_history=10,
        param=False,
        forcing=False,
        latent=True,
    ),
    ds_args=dict(
        ssr=1,
        n_history=10,
        n_horizon=1,
        batch_size=20,
        latent_key='f',
        forcing_key='',
        param_key='',
    ),
    callback=plot_heatmap,
    lr=0.05,
    weight_decay=0.00,
    n_iter=10,
    pen_1=30,
)