# Exercise 33 - Iterative Forward Solver
### Task
Compare the effect of a linear, a fully connected neural network, and a convolutional neural network ansatz for the material distribution on the inversion quality of full waveform inversion with an iterative forward solver. The ansatz is selected via the variable `selectModel`. If necessary, adjust the number of epochs.

### Learning goals
- Familiarize yourself with the syntax of the iterative forward solver for full waveform inversion
- Gain intuition about the three ansatz formulations for the material distribution

In [None]:
import torch
import pandas as pd
import numpy as np
import time
import matplotlib.pyplot as plt

In [None]:
import FiniteDifference

In [None]:
# Work in double precision on the CPU. Seed torch's RNG so that the randomly
# initialized ansatz parameters (and the random CNN input) are reproducible.
torch.set_default_dtype(torch.float64)
device = torch.device("cpu")
torch.manual_seed(11)

## Select material distribution ansatz

In [None]:
# Material distribution ansatz: one of "Linear", "FNN" or "CNN".
selectModel = "Linear"
# selectModel = "FNN"
# selectModel = "CNN"

## Ansatz helper functions

**weight initialization (fully connected layers) and normalization (convolutional layers)**

In [None]:
def init_weights(m):
    """Initialize fully connected layers with Xavier (Glorot) uniform weights.

    Meant to be passed to ``torch.nn.Module.apply``, which calls it on every
    submodule. Only ``torch.nn.Linear`` layers (including subclasses) are
    modified; every other module type, e.g. ``torch.nn.Conv1d``, keeps
    PyTorch's default initialization.

    The Xavier gain is the one recommended for a leaky ReLU with negative
    slope 0.2. Biases, when the layer has one, are set to zero.

    Parameters
    ----------
    m : torch.nn.Module
        Submodule visited by ``Module.apply``.
    """
    if isinstance(m, torch.nn.Linear):
        gain = torch.nn.init.calculate_gain('leaky_relu', 0.2)
        torch.nn.init.xavier_uniform_(m.weight, gain=gain)
        # Linear(..., bias=False) has m.bias == None; skip it instead of crashing.
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)


class PixelNorm(torch.nn.Module):
    """Scale the input to unit root-mean-square along one dimension.

    Computes ``x / sqrt(sum(x**2, dim) / x.shape[dim] + eps)``, i.e. divides by
    the RMS over ``dim`` (with ``eps`` for numerical safety). With the default
    ``dim=2`` and a Conv1d-style input of shape (batch, channels, length), each
    channel of each sample is normalized over its length.

    NOTE(review): the "pixelwise feature vector normalization" of ProGAN
    normalizes across channels (dim=1); this layer defaults to dim=2 — confirm
    that this is intended.

    Parameters
    ----------
    dim : int
        Dimension over which the mean square is taken (default 2).
    eps : float
        Added to the mean square before the square root (default 1e-8).
    """

    def __init__(self, dim=2, eps=1e-8):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x):
        mean_square = torch.sum(x ** 2, dim=self.dim, keepdim=True) / x.shape[self.dim]
        return x / torch.sqrt(mean_square + self.eps)

    def extra_repr(self):
        return "dim={}, eps={}".format(self.dim, self.eps)

**linear ansatz**

In [None]:
class LinearAnsatz(torch.nn.Module):
    """Material ansatz whose trainable parameters are the coefficients themselves.

    Holds one coefficient per entry of a tensor of shape (1, 1, Nx + 3), all
    initialized to ``init``. In this notebook that length matches the
    ghost-cell-padded grid ``x``.
    """

    def __init__(self, Nx, device, init=1.):
        super().__init__()
        shape = (1, 1, Nx + 3)
        initial_values = init * torch.ones(shape, device=device)
        self.coefficients = torch.nn.Parameter(initial_values)

    def forward(self, dummy):
        """Return the coefficient tensor; the argument is ignored."""
        return self.coefficients

**fully connected neural network ansatz**

In [None]:
class FNN(torch.nn.Module):
    """Fully connected network mapping coordinates to a material field in (0, 1).

    Architecture: Linear -> LeakyReLU, then Linear -> PReLU(0.2) for every
    further hidden layer, then Linear -> Sigmoid. Linear layers are
    re-initialized with ``init_weights``.

    Parameters
    ----------
    input_dimension : int
        Number of input features per point.
    hidden_dimension : sequence of int
        Widths of the hidden layers; must contain at least one entry.
    output_dimension : int
        Number of output features per point.

    Raises
    ------
    ValueError
        If ``hidden_dimension`` is empty.
    """

    def __init__(self, input_dimension, hidden_dimension, output_dimension):
        super().__init__()
        if len(hidden_dimension) == 0:
            raise ValueError("hidden_dimension must contain at least one layer width")

        modules = []
        modules.append(torch.nn.Linear(input_dimension, hidden_dimension[0]))
        # NOTE(review): LeakyReLU() uses PyTorch's default negative slope (0.01),
        # while the later PReLUs start at 0.2 and init_weights uses the gain for
        # slope 0.2 — confirm this first activation is intended.
        modules.append(torch.nn.LeakyReLU(inplace=True))
        for width_in, width_out in zip(hidden_dimension[:-1], hidden_dimension[1:]):
            modules.append(torch.nn.Linear(width_in, width_out))
            modules.append(torch.nn.PReLU(init=0.2))

        modules.append(torch.nn.Linear(hidden_dimension[-1], output_dimension))

        # Scale output between 0 and 1 with Sigmoid
        modules.append(torch.nn.Sigmoid())

        self.model = torch.nn.Sequential(*modules)
        self.model.apply(init_weights)

    def forward(self, x):
        """Evaluate the network and reshape for the solver.

        For ``x`` of shape (N, input_dimension) and ``output_dimension == 1`` the
        result has shape (1, 1, N), also when N == 1. Only the trailing feature
        axis is squeezed; a blanket ``squeeze()`` would also drop N == 1.
        """
        out = self.model(x)
        return out.squeeze(-1).unsqueeze(0).unsqueeze(0)

**convolutional neural network ansatz**

In [None]:
class CNN(torch.nn.Module):
    """Convolutional generator mapping a latent tensor to a material field in (0, 1).

    Expects input of shape (batch, 128, L). There are four Conv1d stages
    (128 -> 64 -> 32 -> 16 -> 1 channels, kernel 3, 'same' padding). Three
    nearest-neighbour Upsample layers each double the length, so the output
    has shape (batch, 1, 8 * L). A final Sigmoid scales the output to (0, 1).
    """

    def __init__(self):
        super().__init__()

        def conv(channels_in, channels_out):
            return torch.nn.Conv1d(channels_in, channels_out, kernel_size=3, padding=1, stride=1)

        def activation():
            return torch.nn.PReLU(init=0.2)

        def upsample():
            return torch.nn.Upsample(scale_factor=2, mode='nearest')

        self.model = torch.nn.Sequential(
            conv(128, 64), PixelNorm(), activation(), upsample(),
            conv(64, 32), PixelNorm(), activation(), upsample(),
            # NOTE(review): activation and normalization are in the opposite order
            # here compared with the two stages above — confirm this is intended.
            conv(32, 16), activation(), PixelNorm(), upsample(),
            conv(16, 1),
            torch.nn.Sigmoid(),  # scale output between 0 and 1
        )
        # init_weights, as written in this notebook, only re-initializes
        # torch.nn.Linear layers, so the Conv1d layers keep PyTorch's defaults.
        self.model.apply(init_weights)

    def forward(self, x):
        return self.model(x)

## Pre-processing

**loading settings of measurement**

In [None]:
settings = pd.read_csv("measurement1DFWI/settings.csv")

# Each setting is the first entry of the column with the same name:
# Lx domain length, Nx number of grid intervals, dt time step,
# N number of time samples, c0 reference material value (scales the ansatz output).
Lx, Nx, dt, N, c0 = (settings[column][0] for column in ("Lx", "Nx", "dt", "N", "c0"))
dx = Lx / Nx  # grid spacing

**grid creation**

In [None]:
# Spatial grid with one ghost cell on each side: Nx + 3 nodes from -dx to Lx + dx,
# which keeps the node spacing equal to dx.
x = np.linspace(-dx, Lx + dx, num=Nx + 3)
# Temporal grid: N samples starting at 0 with step dt.
t = np.linspace(0, dt * (N - 1), num=N)
x_, t_ = np.meshgrid(x, t, indexing='ij')

**loading measurements**

In [None]:
numberOfSources = 2
# Per-source data on the Nx + 1 non-ghost nodes: source terms over N time
# samples and recorded signals over N + 1 time samples.
fm = np.zeros((numberOfSources, Nx + 1, N))
um = np.zeros((numberOfSources, Nx + 1, N + 1))
for source in range(numberOfSources):
    fm[source] = pd.read_hdf(f"measurement1DFWI/source{source}.h5").values
    um[source] = pd.read_hdf(f"measurement1DFWI/signal{source}.h5").values
# Ground-truth material distribution (first column of the file).
cm = np.array(pd.read_hdf("measurement1DFWI/material.h5").values)[:, 0]

# Presumably the grid indices of the two sensors (first and last node) — confirm in FiniteDifference.
sensorPositions = (0, -1)

**initial conditions**

In [None]:
# Zero initial fields on the padded grid, handed to the finite-difference solver.
# Presumably the first two time levels of the wave field — confirm in FiniteDifference.
u0 = 0 * x
u1 = 0 * x

## Hyperparameter selection & model/ansatz initialization

In [None]:
def _rescale_to_unit_interval(tensor):
    """Min-max scale *tensor* to [-1, 1] (normalize and center input data)."""
    return (tensor - torch.min(tensor)) / (torch.max(tensor) - torch.min(tensor)) * 2 - 1


# model definition and ansatz-specific hyperparameters
# (lr: Adam learning rate; alpha, beta: learning-rate decay (beta * epoch + 1) ** alpha;
#  costScaling: factor applied to the solver gradient; clip: max gradient norm)
if selectModel == "Linear":
    model = LinearAnsatz(Nx, device)
    modelInput = torch.from_numpy(x).unsqueeze(1)  # dummy: LinearAnsatz ignores its input

    # hyperparameters
    lr = 2e-2
    alpha = -0.5
    beta = 0.2
    epochs = 3000
    costScaling = 1e8
    clip = 1e-2

elif selectModel == "FNN":
    model = FNN(1, [100, 100], 1)
    modelInput = _rescale_to_unit_interval(torch.from_numpy(x).unsqueeze(1))

    # hyperparameters
    lr = 1e-2
    alpha = -0.5
    beta = 0.4  # IMPORTANT PARAMETER
    epochs = 3000
    costScaling = 1e8
    clip = 1e-2  # also tried: 1e-3

elif selectModel == "CNN":
    # The model is created before the random input so the RNG stream (seeded above) is unchanged.
    model = CNN()
    modelInput = _rescale_to_unit_interval(torch.randn((1, 128, 15), device=device))

    # hyperparameters
    lr = 1e-2  # also tried: 2e-3, 2e-2, 5e-3
    alpha = -0.5  # also tried: -0.2
    beta = 0.8  # also tried: 0.5
    epochs = 3000
    costScaling = 1e8
    clip = 1e-3

else:
    # Fail here instead of with a NameError on `model` further down.
    raise ValueError("Unknown selectModel {!r}; expected 'Linear', 'FNN' or 'CNN'.".format(selectModel))

numberOfParameters = sum(parameter.numel() for parameter in model.parameters())
print("number of parameters: {:d}".format(numberOfParameters))

## Optimizer setup

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=lr)


def lr_lambda(epoch):
    """Learning-rate multiplier (beta * epoch + 1) ** alpha; equals 1 at epoch 0."""
    return (beta * epoch + 1) ** alpha


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

## Training

**training setup**

In [None]:
costHistory = np.zeros(epochs)  # cost value recorded for every epoch
# start0 marks the beginning of training; start is reset at every progress report.
start0 = start = time.perf_counter()
model.train()

**training loop**

In [None]:
for epoch in range(epochs):
    optimizer.zero_grad(set_to_none=True)

    # Material field on the ghost-cell-padded grid, scaled by the reference value c0.
    cpred = c0 * model(modelInput)
    cpred[:, :, :1] = c0  # assuming boundary values to be intact
    cpred[:, :, -1:] = c0  # assuming boundary values to be intact

    # The forward/adjoint problems are solved outside PyTorch on a NumPy copy of the field.
    # The solver returns `gradient`, presumably d(cost)/d(c) on the interior entries
    # 1..Nx+1 (it is written into cpred[0, 0, 1:-1] below) — confirm in FiniteDifference.
    gradient, cost = FiniteDifference.getAllAdjointSensitivities(u0, u1, fm, cpred[0, 0].detach().cpu().numpy(),
                                                                 dx, Nx, dt, N, um, sensorPositions)

    # Upstream gradient for cpred: the solver sensitivities on the interior, zero on the two
    # pinned boundary entries. A local tensor is used instead of `cpred.grad`: `.grad` is the
    # slot autograd fills for leaf tensors, and cpred is a non-leaf result of the ansatz.
    cpredGradient = torch.zeros_like(cpred)
    cpredGradient[0, 0, 1:-1] = torch.from_numpy(gradient)

    # Chain rule: push the externally computed sensitivities back into the ansatz parameters.
    # explanation: https://web.archive.org/web/20221026061918/https://medium.com/@monadsblog/pytorch-backward-function-e5e2b7e60140
    cpred.backward(costScaling * cpredGradient)

    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
    optimizer.step()
    scheduler.step()

    costHistory[epoch] = cost

    if selectModel == "Linear":
        # Clamping instead of sigmoid: keep the coefficients in [0, 1] like the FNN/CNN outputs.
        with torch.no_grad():
            model.coefficients.clamp_(0., 1.)

    if epoch % 100 == 0:
        elapsed_time = time.perf_counter() - start
        string = "Epoch: {}/{}\t\tCost function: {:.3E}\t\tElapsed time: {:.2f}"
        print(string.format(epoch, epochs - 1, costHistory[epoch], elapsed_time))
        start = time.perf_counter()

print("Total elapsed training time: {:.2f}".format(time.perf_counter() - start0))

**prediction of material distribution**

In [None]:
model.eval()
# Inference only: no autograd graph is needed for the final material prediction.
# Unlike in training, the boundary entries are not pinned to c0; the plot below uses
# only the interior entries [1:-1].
with torch.no_grad():
    cpred = c0 * model(modelInput).squeeze().cpu()

## Post-processing

**predicted material distribution & true material distribution**

In [None]:
# Predicted vs. true material distribution, dropping the first and last grid entries
# (the ghost cells of x).
interior = slice(1, -1)
fig, ax = plt.subplots()
ax.plot(x[interior], cpred[interior], 'k--', label="prediction")
ax.plot(x[interior], cm[interior], 'r', label="ground truth")
ax.set_xlabel("$x$")
ax.set_ylabel("$c(x)$")
ax.grid()
ax.legend()
fig.tight_layout()
plt.show()

**learning history**

In [None]:
# Cost per epoch on a logarithmic axis.
fig, ax = plt.subplots()
ax.plot(costHistory, 'k')
ax.set_yscale('log')
ax.set_xlabel("epoch")
ax.set_ylabel("cost")
ax.grid()
fig.tight_layout()
plt.show()