# Quick tutorial on implicit neural emulators

- We load a dataset (integration step-size dt=900) and a pre-trained implicit model and use it to generate a simulation.
- We have a look at the learned systems of linear equations of the implicit model (defined by tensors M,b).
- We check the convergence of the Red-Black Gauss-Seidel solver in a typical forward pass and a typical backward pass.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import pytorch_lightning as pl
from implem.utils import init_torch_device, as_tensor, device, dtype, dtype_np
from implem.model import LightningImplicitModel, LightningImplicitPretrainModel

# Problem variant used throughout the notebook: it is interpolated into the
# hydra override ``data=swe{swe_model}`` and handed to the data loader and
# the plotting helper. Only '1D' and '2D' are accepted.
swe_model = '2D'
assert swe_model in ('1D', '2D')


# data loading / formatting

In [None]:
from src.train import load_data_swe
from implem.data import DataModule, MultiStepMultiTrialDataset

# Load the dt=900 dataset for the selected problem variant with channel
# normalization switched on. Besides the data itself the loader returns a
# description, the number of static channels and the normalization scales.
load_kwargs = dict(
    filedir='./data',
    data_fn='dataset_dt900',
    instance_dimensionality=swe_model,
    normalize_channels=True,
)
data, data_descr, n_static_channels, data_scales = load_data_swe(**load_kwargs)

batch_size = 16
offset = np.arange(1, 2)  # == array([1])

# Wrap the raw data in a data module that serves (input, target) batches.
dm = DataModule(
    data=data,
    batch_size=batch_size,
    offset=offset,
    Dataset=MultiStepMultiTrialDataset,
)
dm.setup()

# Look at the first training batch only; x and y stay available afterwards
# as an example input/target pair.
for batch in dm.train_dataloader():
    x, y = batch
    shapes = (x.shape, y.shape)
    print('example input- and output tensor shapes:', shapes)
    break


# model definition

In [None]:
from implem.model import LightningImplicitModel, ImplicitLayer
import hydra
from omegaconf import read_write

from src.utils import fix_bM_BCs, enforce_bcs, network_init_swe

# Compose the hydra configuration for the selected problem variant and the
# implicit model.
with hydra.initialize(config_path='configs'):
    cfg = hydra.compose(config_name='defaults.yaml', overrides=[f"data=swe{swe_model}", "model=model_implicit"])

model = hydra.utils.instantiate(
        cfg.model,
        instance_dimensionality = cfg.data.instance_dimensionality,
        input_channels = data.shape[2],
        static_channels = n_static_channels, # do not predict channels that do not change over time
        offset = [1],                        # how many steps into future to predict 
        system_determ=[fix_bM_BCs, None],    # enforce BCs on system of linear equations
        format_output = enforce_bcs,         # enforce BCs on velocities
        _recursive_ = False # necessary for model.configure_optimizers()
)

# initialize model: fix some layers to ensure output of linear solve is water height.
# At this point x is still the example batch from the data-loading cell.
# dt is set to the integration step-size of the loaded dataset
# ('dataset_dt900'); the remaining constants are the same values used for the
# true stencils further below.
network_init_swe(model, x, data_scales, dt=900., dx=1e4, w_imp=0.5, g=9.81, cd=1e-3, ah=1e3)

# Sanity check: push a single state (second-to-last step of trial 0) through
# the freshly initialized model and print the output shape.
t_start = -2
x = as_tensor(data[0, t_start]).unsqueeze(0)
# Call the module itself rather than .forward(), as the simulation loop does.
print(model(x).shape)

print(model)

# model fitting

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint
from omegaconf import OmegaConf
from os import listdir


# Set to True to train the model defined above inside the notebook; otherwise
# a provided pre-trained checkpoint is loaded.
train_in_notebook = False


if train_in_notebook:

    # Early stopping on the validation loss is only added when GPUS >= 1;
    # with GPUS < 1 the trainer runs without extra callbacks.
    GPUS = 1
    if GPUS < 1:
        callbacks = None
    else:
        callbacks = [
            pl.callbacks.EarlyStopping(monitor='loss/val', patience=50),
        ]

    logger_tb = pl.loggers.TensorBoardLogger(".", "", "", 
                                             log_graph=True, 
                                             default_hp_metric=False)
    trainer = pl.Trainer(**OmegaConf.to_container(cfg.trainer),
                         logger=logger_tb, 
                         callbacks=callbacks)
    trainer.fit(model, dm)

else:

    # load provided implicit model for integration step-size dt = 900.  
    model_str = 'SWE2D_dt900/ImplicitModel/min_net0/2021-09-20_135046/'
    ckpt_dir = f"outputs/{model_str}checkpoints/"

    # os.listdir() returns entries in arbitrary order, so sort the names to
    # make the choice of "last" checkpoint reproducible across machines.
    fls = sorted(listdir(ckpt_dir))
    if not fls:
        raise FileNotFoundError(f"no checkpoint files found in {ckpt_dir}")

    # Re-compose the configuration that was stored alongside the checkpoint.
    with hydra.initialize(config_path=f"outputs/{model_str}hydra/"):
        cfg = hydra.compose(config_name='config.yaml')
    print('cfg', cfg)
        
    model = LightningImplicitModel.load_from_checkpoint(
        checkpoint_path=ckpt_dir + fls[-1], 
        **cfg['model'], 
        instance_dimensionality = cfg.data.instance_dimensionality,
        input_channels=data.shape[2],
        static_channels = n_static_channels,
        offset = offset,
        system_determ=[fix_bM_BCs, None],
        format_output = enforce_bcs,
    )
    print('model', model)

model = model.to(device)

# model evaluation

In [None]:
# Roll the model forward from one dataset frame so the emulated trajectory
# can be compared with the stored one.

n, t_start, t_end = -1, 20, 101 # trial, starting step, stopping step
n_model_steps = t_end - t_start - 1

out = []
with torch.no_grad():
    for t in range(n_model_steps):
        if t > 0:
            # Later steps: feed the previous prediction back in, re-attaching
            # the static channels taken from the frame at t_start.
            static_part = as_tensor(data[n, t_start][-n_static_channels:])
            x = torch.cat((out[-1][0], static_part), dim=0).unsqueeze(0)
        else:
            # First step: start from the dataset frame and store its
            # non-static channels as the first trajectory entry.
            x = as_tensor(data[n, t_start + t]).unsqueeze(0)
            out.append(x[0, :-n_static_channels].unsqueeze(0))

        x_est = model(x)
        out.append(x_est)

# Stack the initial state and all predictions along the first axis.
out = torch.cat(out, dim=0)

In [None]:
# Compare stored and emulated trajectories for one channel.

from src.utils import plot_results_swe

i = 0  # index of the channel to plot
reference = data[n, t_start:t_end, i, :]
emulated = out[:, i, :].cpu().numpy()
plot_results_swe(
    data_numerical=reference,
    data_model=emulated,
    i=i,
    swe_model=swe_model,
    if_save=False,
    fig_path=None,
)

# compare maps of learned with true stencils for this problem

In [None]:
from src.utils import swe2D_true_bM

# Constants handed to swe2D_true_bM.
dt = 900. # time step [s]
dx = 1e4 # grid spacing [m]
g = 9.81
w_imp = 0.5
cd = 1e-3
ah = 1000.

n, t_start = 50, -2

# Reference right-hand side and system tensor for one example state.
x = as_tensor(data[n, t_start]).unsqueeze(0)
b_true, M_true, us, vs = swe2D_true_bM(x, dt, dx, g, w_imp, cd, ah, data_scales, comp_u='calculate')

# Map of b, plus its mean with the outermost rows/columns left out.
b_field = b_true[0, 0]
plt.figure(figsize=(16, 5))
plt.imshow(b_field.detach().cpu().numpy().T)
print(b_field[1:-1, 1:-1].detach().cpu().numpy().mean())
plt.title('right-handside b')
plt.colorbar()
plt.show()

# One map per entry along the third axis of M_true, with the same
# interior mean printed for each.
for i, M_c in enumerate(M_true[0, 0]):
    plt.figure(figsize=(16, 5))
    plt.title(r'tensor $\left(M_\phi\right)_{c}$' + f', c = {i + 1}')
    plt.imshow(M_c.detach().cpu().numpy().T)
    print(M_c[1:-1, 1:-1].detach().cpu().numpy().mean())
    plt.colorbar()
    plt.show()

In [None]:
# Right-hand side b and system tensor M produced by the model's first
# implicit layer for the same example state.

x = as_tensor(data[n, t_start]).unsqueeze(0)
model = model.to(device)
M, b, _ = model.impl_layers[0]._forward(x)

# Map of b, plus its mean with the outermost rows/columns left out.
b_field = b[0, 0]
plt.figure(figsize=(16, 5))
plt.imshow(b_field.detach().cpu().numpy().T)
plt.title('right-handside b')
print(b_field[1:-1, 1:-1].detach().cpu().numpy().mean())
plt.colorbar()
plt.show()

# First five entries along the third axis of M, one map each.
for i in range(5):
    M_c = M[0, 0, i]
    plt.figure(figsize=(16, 5))
    plt.imshow(M_c.detach().cpu().numpy().T)
    plt.title(r'tensor $\left(M_\phi\right)_{c}$' + f', c = {i + 1}')
    print(M_c[1:-1, 1:-1].detach().cpu().numpy().mean())
    plt.colorbar()
    plt.show()

# convergence of forward-pass for learned (M,b) and example x

In [None]:
from implem.utils import transpose_compact_blockmat, transpose_compact_blockmat_sep_eqs_per_field
from implem.utils import banded_gauss_seidel_redblack, biCGstab_l

# Work on a shallow copy of the layer's solver settings. Assigning into the
# layer's own settings object would permanently change the threshold and
# iteration budget the model uses in every later forward pass.
settings = dict(model.impl_layers[0].settings)

settings['thresh'] = 1e-15 # for training we use 1e-14
settings['max_iter'] = 50
# Start the iteration from the first channel of the example input x.
settings['x_init'] = 1.* x[:,:1]

# Solve the learned system (M, b) directly with the Red-Black Gauss-Seidel
# solver and keep the per-iteration diagnostics it returns.
with torch.no_grad():
    z, diagnostics = banded_gauss_seidel_redblack(M, b,
                                        **settings)

plt.semilogy(diagnostics[:,0,0].detach().cpu().numpy())
plt.ylabel('MSE on |Az-b|')
plt.xlabel('Red-Black Gauss-Seidel iterations')
plt.show()

# convergence of backward-pass solve for learned (M,b) and example x

In [None]:
from implem.utils import transpose_compact_blockmat, transpose_compact_blockmat_sep_eqs_per_field
from implem.utils import banded_gauss_seidel_redblack, biCGstab_l

# Forward pass split at the output of the first implicit layer, so that the
# gradient of the loss with respect to z can be read off afterwards.
x = as_tensor(data[n, t_start]).unsqueeze(0)
z = model.impl_layers[0].forward(x).detach()
z.requires_grad = True
pred = model.impl_layers[1].forward(z)
# The model predicts one step ahead (offset 1), so the target is the
# non-static part of the state following the input state.
y = as_tensor(data[n, t_start + 1])[:-n_static_channels].unsqueeze(0)

loss_function = torch.nn.MSELoss()
loss = loss_function(input=pred, target=y)
loss.backward()

# Right-hand side of the backward solve: gradient w.r.t. the first channel of z.
dLdz = z.grad[:,:1]

# Work on a shallow copy of the layer's solver settings so the overrides
# below do not leak into the model itself.
settings = dict(model.impl_layers[0].settings)

settings['thresh'] = 1e-25 # for training we use 1e-24
settings['max_iter'] = 50

if settings['sep_eqs_per_field']:
    transpose_M = transpose_compact_blockmat_sep_eqs_per_field
    start_flatten_dim_M = 3 # M.shape = (N, L, K, *spatial_dims)
else:
    transpose_M = transpose_compact_blockmat
    start_flatten_dim_M = 4 # M.shape = (N, L, L, K, *spatial_dims)

# The backward pass solves the transposed system: flatten the spatial
# dimensions, transpose the compact banded representation, restore the shape.
MT = transpose_M(M.flatten(start_dim=start_flatten_dim_M),
                         offdiagonals=settings['offdiagonals']).reshape(M.shape)

# Zero initial guess, padded as configured for the backward solve.
settings['x_init'] = torch.nn.functional.pad(input = torch.zeros_like(dLdz), 
                                             pad = settings['pad_x_backward_init'])

with torch.no_grad():
    z, diagnostics = banded_gauss_seidel_redblack(MT, dLdz,
                                **settings)

plt.semilogy(diagnostics[:,0,0].detach().cpu().numpy())
plt.ylabel('MSE on |Az-b|')
plt.xlabel('Red-Black Gauss-Seidel iterations')
plt.show()