# Exercise 32 - Physics-Informed Neural Networks for Inverse Problems
### Task
Compare how a linear ansatz, a fully connected neural network ansatz, and a convolutional neural network ansatz for the material distribution affect the inversion quality of a physics-informed neural network for full waveform inversion. The ansatz is selected via `selectModel`. If necessary, adjust the number of epochs.

### Learning goals
- Familiarize yourself with the syntax of the physics-informed neural network for full-domain full waveform inversion
- Gain intuition about the three ansatz formulations for the material distribution

In [None]:
import torch
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import time

In [None]:
# Tensors created without an explicit dtype are float64 from here on.
torch.set_default_dtype(torch.float64)
device = torch.device('cpu')
# Seed PyTorch's RNG so that weight initialization and the random CNN input
# (torch.randn) are reproducible between runs.
torch.manual_seed(2)

## Select material distribution ansatz

In [None]:
# Material distribution ansatz: one of "Linear", "FNN" or "CNN".
selectModel = "FNN"

## Ansatz helper functions

**weight initialization for fully connected layers and normalization for convolutional layers**

In [None]:
def init_weights(m):
    """Initialize a fully connected layer with Xavier-uniform weights and zero bias.

    Intended for ``torch.nn.Module.apply``, which passes every submodule in
    turn. Only ``torch.nn.Linear`` layers (and subclasses) are modified; all
    other modules, including ``torch.nn.Conv1d``, keep PyTorch's default
    initialization.

    The Xavier gain is that of a leaky ReLU with negative slope 0.2, which
    matches the initial slope of the ``PReLU(init=0.2)`` activations.
    """
    if isinstance(m, torch.nn.Linear):
        gain = torch.nn.init.calculate_gain('leaky_relu', 0.2)
        torch.nn.init.xavier_uniform_(m.weight, gain=gain)
        # Layers built with bias=False have m.bias set to None.
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)


class PixelNorm(torch.nn.Module):
    """Rescale every slice to unit root-mean-square along dimension 2.

    For a 3-D input such as a Conv1d feature map (batch, channels, length),
    each (batch, channel) row is divided by its RMS over the length axis.
    The constant 1e-8 inside the square root prevents division by zero.

    NOTE(review): StyleGAN-style pixel normalization averages over channels
    (dim 1); this module averages over dim 2. Confirm that this is intended.
    """

    def __init__(self):
        super(PixelNorm, self).__init__()

    def forward(self, x):
        mean_square = x.pow(2).sum(dim=2, keepdim=True) / x.shape[2]
        return x / (mean_square + 1e-8).sqrt()

**linear ansatz**

In [None]:
class LinearAnsatz(torch.nn.Module):
    """Material ansatz with one free, directly optimized coefficient per grid point.

    The learnable tensor has shape ``(1, 1, Nx + 3)``. This is the same
    (batch, channel, length) layout the network ansatzes return, and matches
    the ``Nx + 3`` points of the ghost-cell grid. Every coefficient starts at
    ``init``.
    """

    def __init__(self, Nx, device, init=1.):
        super().__init__()
        shape = (1, 1, Nx + 3)
        initial_values = init * torch.ones(shape, device=device)
        self.coefficients = torch.nn.Parameter(initial_values)

    def forward(self, dummy):
        # The input is ignored; the coefficients themselves are the prediction.
        return self.coefficients

**fully connected neural network ansatz**

In [None]:
class FNN(torch.nn.Module):
    """Fully connected material ansatz whose outputs lie in (0, 1).

    Architecture: one ``Linear -> PReLU(init=0.2)`` pair per entry of
    ``hidden_dimension``, then a final ``Linear`` layer followed by a
    ``Sigmoid``. All submodules are passed through ``init_weights`` via
    ``Module.apply``.

    Args:
        input_dimension: number of input features per sample.
        hidden_dimension: non-empty sequence of hidden layer widths.
        output_dimension: number of output features per sample.
    """

    def __init__(self, input_dimension, hidden_dimension, output_dimension):
        super().__init__()

        widths = [input_dimension, *hidden_dimension]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [torch.nn.Linear(fan_in, fan_out), torch.nn.PReLU(init=0.2)]

        # The Sigmoid bounds the output to (0, 1).
        layers += [torch.nn.Linear(hidden_dimension[-1], output_dimension), torch.nn.Sigmoid()]

        self.model = torch.nn.Sequential(*layers)
        self.model.apply(init_weights)

    def forward(self, x):
        # Remove singleton dimensions, then prepend two leading axes
        # (a (1, 1, n) layout for a 1-D result).
        return self.model(x).squeeze()[None, None]

**convolutional neural network ansatz**

In [None]:
class CNN(torch.nn.Module):
    """Convolutional material ansatz decoding a latent tensor into values in (0, 1).

    Three stages of ``Conv1d`` (kernel 3, padding 1, stride 1, so the length is
    preserved), ``PReLU``, ``PixelNorm`` and nearest-neighbour 2x upsampling
    reduce the channels 128 -> 64 -> 32 -> 16 while multiplying the length by 8.
    A final ``Conv1d`` maps to one channel and a ``Sigmoid`` bounds the result.
    An input of shape (batch, 128, L) yields an output of shape (batch, 1, 8 * L).
    """

    def __init__(self):
        super().__init__()

        channels = (128, 64, 32, 16)
        layers = []
        for c_in, c_out in zip(channels, channels[1:]):
            layers.extend([
                torch.nn.Conv1d(c_in, c_out, kernel_size=3, padding=1, stride=1),
                torch.nn.PReLU(init=0.2),
                PixelNorm(),
                torch.nn.Upsample(scale_factor=2, mode='nearest'),
            ])
        layers.append(torch.nn.Conv1d(channels[-1], 1, kernel_size=3, padding=1, stride=1))
        # The Sigmoid bounds the output to (0, 1).
        layers.append(torch.nn.Sigmoid())

        self.model = torch.nn.Sequential(*layers)
        # NOTE(review): init_weights as written in this notebook only modifies
        # torch.nn.Linear layers, so the Conv1d layers keep PyTorch's default
        # initialization. Confirm that this is intended.
        self.model.apply(init_weights)

    def forward(self, x):
        return self.model(x)

## Physics-informed residual

In [None]:
def getResidual(cpred, um, fm, i, dx, dt):
    """Return the pointwise squared residual of the discretized 1D wave equation.

    The wavefield ``um[i]`` and source ``fm[i]`` of source ``i`` are inserted
    into a finite-difference scheme in which both sides are multiplied by dt**2:

        u[x, t+1] - 2 u[x, t] + u[x, t-1]
            - (dt/dx)**2 * (c2_R * (u[x+1] - u[x]) - c2_L * (u[x] - u[x-1]))
            - dt**2 * f

    Here c2_R and c2_L are harmonic means of ``cpred**2`` at neighbouring
    nodes. The first/last spatial nodes and the first/last time columns are
    dropped so that all terms share one shape.

    Args:
        cpred: predicted wave speed as a column vector of shape (n_x, 1), so
            that it broadcasts over the time axis.
        um: wavefields indexed by source; ``um[i]`` has shape (n_x, n_t + 1).
        fm: source terms indexed by source; ``fm[i]`` has shape (n_x, n_t).
        i: source index.
        dx: grid spacing.
        dt: time step.

    Returns:
        Tensor of shape (n_x - 2, n_t - 1) with the squared residual.
    """
    # The wavefield comes straight from the measurements; it should also be
    # a NN in case of partial domain knowledge.
    u = um[i]
    source = fm[i]

    # Central second difference in time (not divided by dt**2).
    u_tt = u[:, 2:] - 2 * u[:, 1:-1] + u[:, :-2]

    # Face values of c**2 as harmonic means of the neighbouring nodal values.
    half_slowness_sq = 0.5 / cpred ** 2
    c2_right = (half_slowness_sq[1:-1] + half_slowness_sq[2:]) ** (-1)
    c2_left = (half_slowness_sq[:-2] + half_slowness_sq[1:-1]) ** (-1)

    # Flux-form second difference in space, scaled by (dt/dx)**2.
    c2_u_xx = (dt / dx) ** 2 * (c2_right * (u[2:] - u[1:-1]) - c2_left * (u[1:-1] - u[:-2]))

    return (u_tt[1:-1] - c2_u_xx[:, 1:-1] - source[1:-1, :-1] * dt ** 2) ** 2

## Pre-processing

**loading settings of measurement**

In [None]:
settings = pd.read_csv("measurement1DFWI/settings.csv")

# The scalar simulation parameters are stored in row 0 of their columns.
Lx, Nx, dt, N, c0 = (settings[column][0] for column in ("Lx", "Nx", "dt", "N", "c0"))
dx = Lx / Nx  # grid spacing

**grid creation**

In [None]:
# Spatial grid with spacing dx from -dx to Lx + dx: the Nx + 1 nodes on
# [0, Lx] plus one ghost node on each side.
x = torch.linspace(start=0 - dx, end=Lx + dx, steps=Nx + 3, device=device)
# N equally spaced time instants from 0 to (N - 1) * dt.
t = torch.linspace(start=0, end=(N - 1) * dt, steps=N, device=device)
# Space-time grids with 'ij' indexing: space along dim 0, time along dim 1.
x_, t_ = torch.meshgrid(x, t, indexing='ij')

**loading measurements**

In [None]:
numberOfSources = 2
# Preallocate on the target device so that the loaded measurements stay on it
# (torch.zeros defaults to the CPU otherwise).
fm = torch.zeros((numberOfSources, Nx + 1, N), device=device)  # source term per source
um = torch.zeros((numberOfSources, Nx + 1, N + 1), device=device)  # signal per source
for i in range(numberOfSources):
    fm[i] = torch.tensor(pd.read_hdf(f"measurement1DFWI/source{i}.h5").values, device=device)
    um[i] = torch.tensor(pd.read_hdf(f"measurement1DFWI/signal{i}.h5").values, device=device)
# Reference material distribution (first column of the stored table).
cm = torch.tensor(pd.read_hdf("measurement1DFWI/material.h5").values, device=device)[:, 0]

## Hyperparameter selection & model/ansatz initialization

In [None]:
# Hyperparameters shared by all ansatzes:
#   lr             - Adam learning rate for the model parameters
#   alpha, beta    - learning-rate schedule factor (beta * epoch + 1) ** alpha
#   epochs         - number of training iterations
#   clip           - max gradient norm for the model parameters
#   weightLrFactor - learning-rate multiplier for the residual weights
# NOTE(review): costScaling is set for "Linear" and "CNN" only and is not
# referenced by any cell shown here. Confirm whether it is still needed.
if selectModel == "Linear":
    model = LinearAnsatz(Nx, device)
    modelInput = x.unsqueeze(1)  # dummy: LinearAnsatz ignores its input

    # hyperparameters
    lr = 1e-2
    alpha = -0.5
    beta = 0.2
    epochs = 1000
    costScaling = 1e8
    clip = 1e-4
    weightLrFactor = 10

elif selectModel == "FNN":
    model = FNN(1, [100, 100, 100], 1)
    modelInput = x.unsqueeze(1)
    modelInput = (modelInput - torch.min(modelInput)) / (
            torch.max(modelInput) - torch.min(modelInput)) * 2 - 1  # normalize and center input data to [-1, 1]

    # hyperparameters
    lr = 2e-3
    alpha = -0.2
    beta = 0.2
    epochs = 1000  # 10000
    clip = 1e-3
    weightLrFactor = 10

elif selectModel == "CNN":
    model = CNN()
    # Fixed random latent input (reproducible through the global seed).
    modelInput = torch.randn((1, 128, 15), device=device)
    modelInput = (modelInput - torch.min(modelInput)) / (
            torch.max(modelInput) - torch.min(modelInput)) * 2 - 1  # normalize and center input data to [-1, 1]

    # hyperparameters
    lr = 1e-2
    alpha = -0.2
    beta = 0.5
    epochs = 1000
    costScaling = 1e8
    clip = 1e-3
    weightLrFactor = 10

else:
    # Fail here with a clear message instead of a NameError on `model` later.
    raise ValueError(
        f"Unknown selectModel {selectModel!r}; expected 'Linear', 'FNN' or 'CNN'.")

print("number of parameters: {:d}".format(sum(p.numel() for p in model.parameters())))

## Optimizer setup

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr)

# One trainable weight per residual entry, in a second parameter group with
# its own learning rate. The weights use the default dtype (float64 here) to
# match the residual they multiply; previously they were forced to float32.
weights = torch.ones((Nx - 1, N - 1), requires_grad=True, device=device)
optimizer.add_param_group({'params': weights, 'lr': lr * weightLrFactor})

# Decay factor (beta * epoch + 1) ** alpha. LambdaLR multiplies it with each
# group's own initial learning rate, so the weight group keeps its
# weightLrFactor ratio throughout training.
lr_lambda = lambda epoch: (beta * epoch + 1) ** alpha
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

## Training

**training setup**

In [None]:
# Unweighted cost of every epoch, filled in by the training loop.
costHistory = np.zeros(epochs)
# `start` is reset for each progress report; `start0` measures the total time.
start0 = start = time.perf_counter()
model.train()

**training loop**

In [None]:
for epoch in range(epochs):
    optimizer.zero_grad(set_to_none=True)

    # Drop the first and last entries (ghost cells), scale by c0, and shape the
    # result as a column vector so it broadcasts over time in getResidual.
    cpred = model(modelInput)[0, 0, 1:-1].unsqueeze(1) * c0

    # Sum the squared PDE residual over every measured source.
    residual = torch.zeros((Nx - 1, N - 1), device=device)
    for i in range(numberOfSources):
        residual += getResidual(cpred, um, fm, i, dx, dt)

    cost = torch.sum(weights * residual)
    costUnweighted = torch.sum(residual.detach())
    cost.backward()
    # Negate the weight gradient: the optimizer then ascends on the residual
    # weights (they grow where the residual is large) while it descends on the
    # model parameters.
    weights.grad *= -1

    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
    optimizer.step()
    # LambdaLR scales every parameter group relative to its own initial
    # learning rate, so the weight group keeps its weightLrFactor ratio
    # without manual adjustment.
    scheduler.step()

    costHistory[epoch] = costUnweighted.item()

    if epoch % 100 == 0:
        elapsed_time = time.perf_counter() - start
        message = "Epoch: {}/{}\t\tCost function: {:.3E}\t\tElapsed time: {:.2f}"
        print(message.format(epoch, epochs - 1, costHistory[epoch], elapsed_time))
        start = time.perf_counter()

print("Total elapsed training time: {:.2f}".format(time.perf_counter() - start0))

**prediction of material distribution**

In [None]:
model.eval()
# Inference only: no autograd graph is needed for the final prediction.
with torch.no_grad():
    # Drop the first and last entries (ghost cells) and rescale by c0.
    cpred = model(modelInput)[0, 0, 1:-1] * c0

## Post-processing

**predicted material distribution & true material distribution**

In [None]:
fig, ax = plt.subplots()
# Skip the ghost cells at both ends of the grid.
ax.plot(x[1:-1], cm[1:-1], 'gray', label='true')
ax.plot(x[1:-1], cpred, 'k', label='predicted')
ax.set_xlabel('$x$')
ax.set_ylabel('$c$')
ax.legend()
plt.show()

**learning history**

In [None]:
fig, ax = plt.subplots()
# Unweighted cost per epoch on a logarithmic y-axis.
ax.semilogy(costHistory, 'k')
plt.show()

**spatio-temporal residual distribution**

In [None]:
frequency = settings.frequency[0]

# Residual from the last training epoch (computed before that epoch's update).
# Space is normalized by Lx and time by 1 / frequency. The 1e-80 offset keeps
# exact zeros finite under the logarithmic color scale.
fig, ax = plt.subplots(figsize=(6, 6))
cp = ax.pcolormesh(
    x_[2:-2, :-1] / Lx,
    t_[2:-2, :-1] * frequency,
    residual.detach() + 1e-80,
    cmap=plt.cm.jet,
    norm=matplotlib.colors.LogNorm(),
    shading='auto',
)
ax.set(xlabel='$x / L_x$ [-]', ylabel='$t f$ [-]', title="residual in $u$")
fig.colorbar(cp)
fig.tight_layout()
plt.show()

**spatio-temporal weighting distribution**

In [None]:
# Final residual weights on the same normalized space-time grid as the residual.
fig, ax = plt.subplots(figsize=(6, 6))
cp = ax.pcolormesh(
    x_[2:-2, :-1] / Lx,
    t_[2:-2, :-1] * frequency,
    weights.detach(),
    cmap=plt.cm.jet,
    shading='auto',
)
ax.set(xlabel='$x / L_x$ [-]', ylabel='$t f$ [-]', title="final weights $\\kappa$")
fig.colorbar(cp)
fig.tight_layout()
plt.show()