# Exercise 29 (3) - Data-Driven Solver: train the surrogate

With the generated data and the identified reduced basis for the wave pressures, train a surrogate model as a data-driven solver. The neural network architecture can be selected with `selectModel` as either a fully connected or a convolutional neural network. Try to improve the performance by adjusting the hyperparameters.

### Learning goals
- Familiarize yourself with data-driven deep learning training workflows using tools such as DataSet and DataLoader
- Understand how dimensionality reduction techniques can be combined with deep learning training

In [None]:
import torch
import pandas as pd
import numpy as np
import time
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

In [None]:
import DataSet

In [None]:
# Seed PyTorch's global random number generator with a fixed value.
torch.manual_seed(2)
# Use the GPU when PyTorch reports one as available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Select neural network architecture

In [None]:
# Architecture switch read by the model construction, training and prediction cells:
# "FNN" selects the fully connected network, "CNN" the convolutional one.
selectModel = "FNN"
#selectModel = "CNN"

## Neural network helper functions

**weight initialization, normalization for convolutional layers & transformation from convolutional to fully connected layers**

In [None]:
def init_weights(m):
    """Initialize a layer in place with Xavier-uniform weights and zero bias.

    Intended to be passed to ``torch.nn.Module.apply``. Only ``Linear`` and
    ``Conv1d`` modules (including subclasses) are touched; every other module
    is left unchanged.

    Args:
        m: the module to initialize.
    """
    if isinstance(m, (torch.nn.Linear, torch.nn.Conv1d)):
        gain = torch.nn.init.calculate_gain('leaky_relu', 0.2)
        torch.nn.init.xavier_uniform_(m.weight, gain=gain)
        # Layers built with bias=False carry bias=None; skip them instead of crashing.
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)


class PixelNorm(torch.nn.Module):
    """Scale the input by the inverse root-mean-square taken along dimension 2.

    The mean of the squared entries is computed over dimension 2 (kept as a
    size-1 axis so it broadcasts), a small constant is added to avoid division
    by zero, and the input is divided by the square root of the result.
    """

    def forward(self, x):
        meanSquare = x.pow(2).sum(dim=2, keepdim=True) / x.size(2)
        return x / (meanSquare + 1e-8).sqrt()


class SqueezeToFNN(torch.nn.Module):
    """Flatten everything after the first axis so linear layers can follow convolutions."""

    def forward(self, x):
        batchSize = len(x)
        # Keep the leading axis, collapse all remaining axes into one.
        return x.reshape(batchSize, -1)

**fully connected neural network**

In [None]:
class FNN(torch.nn.Module):
    """Fully connected network: ReLU-activated hidden layers plus a linear output layer.

    With the default arguments the architecture is the original hard-coded
    one (100 -> 100 -> 100 -> 100 -> 100 -> 3), so ``FNN()`` is unchanged.

    Args:
        inputSize: number of input features per sample.
        hiddenSize: width of every hidden layer.
        hiddenLayers: number of (Linear, ReLU) pairs before the output layer.
        outputSize: number of values predicted per sample.
    """

    def __init__(self, inputSize=100, hiddenSize=100, hiddenLayers=4, outputSize=3):
        super().__init__()

        modules = []
        width = inputSize
        for _ in range(hiddenLayers):
            modules.append(torch.nn.Linear(width, hiddenSize))
            modules.append(torch.nn.ReLU())
            width = hiddenSize
        # No activation after the last layer: the network outputs raw regression values.
        modules.append(torch.nn.Linear(width, outputSize))

        self.model = torch.nn.Sequential(*modules)
        # apply() visits every submodule, so each Linear layer is re-initialized.
        self.model.apply(init_weights)

    def forward(self, x):
        """Run the input through the layer stack."""
        return self.model(x)

**convolutional neural network**

In [None]:
class CNN(torch.nn.Module):
    """1D convolutional feature extractor followed by a fully connected head.

    Four (Conv1d -> PixelNorm -> PReLU -> MaxPool1d) stages with channel
    counts 4 -> 8 -> 16 -> 8 -> 4 are flattened and fed into three linear
    layers (164 -> 100 -> 100 -> 3). All layers that own parameters are
    created on the module-level ``device``.
    """

    def __init__(self):
        super().__init__()

        channels = (4, 8, 16, 8, 4)
        layers = []
        for nIn, nOut in zip(channels[:-1], channels[1:]):
            layers += [
                torch.nn.Conv1d(nIn, nOut, kernel_size=3, stride=1, padding=0, device=device),
                PixelNorm(),
                torch.nn.PReLU(init=0.2, device=device),
                torch.nn.MaxPool1d(kernel_size=2, stride=2),
            ]

        layers.append(SqueezeToFNN())

        # 164 input features: the flattened output of the last stage
        # (4 channels, so 41 positions per channel are expected).
        layers += [
            torch.nn.Linear(164, 100, device=device),
            torch.nn.PReLU(init=0.2, device=device),
            torch.nn.Linear(100, 100, device=device),
            torch.nn.PReLU(init=0.2, device=device),
            torch.nn.Linear(100, 3, device=device),
        ]

        self.model = torch.nn.Sequential(*layers)
        self.model.apply(init_weights)

    def forward(self, x):
        """Run the input through the convolutional stages and the linear head."""
        return self.model(x)

## Pre-processing

**loading settings of measurements**

In [None]:
# Load the measurement settings table. Later cells read entry 0 of its
# Lx, Nx, c0 and N columns.
settings = pd.read_csv("dataset1DFWI/settings.csv")

**load data set**

In [None]:
# Build the full data set (the selected device is handed to the data set class),
# then split it 90 % / 10 % into training and validation subsets. The dedicated,
# seeded generator keeps the split identical between runs.
dataset = DataSet.FullWaveFormInversionDataset1D(settings, device)
datasetTraining, datasetValidation = torch.utils.data.random_split(
    dataset,
    lengths=[0.9, 0.1],
    generator=torch.Generator().manual_seed(2),
)

## Hyperparameter selection, data preparation & model initialization

In [None]:
# Build the selected model and set its hyperparameters.
#   lr         initial learning rate of the optimizer
#   batchSize  number of samples per training batch
#   alpha/beta learning-rate decay factor (beta * step + 1) ** alpha
#   epochs     number of passes over the training set
#   clip       maximum gradient norm used for gradient clipping
#   l2         weight decay passed to the optimizer
if selectModel == "FNN":
    # The FNN layers are created on the CPU; move them to the selected device
    # (the CNN layers are already created there).
    model = FNN().to(device)

    # Reduced basis used to project the measurements before they enter the FNN.
    # map_location avoids a device mismatch if the file was saved on another
    # device (assumes the data set returns tensors on `device` -- TODO confirm).
    SVDBasisU = torch.load("dataset1DFWI/measurementBasis.pt", map_location=device, weights_only=True)

    # hyperparameters
    lr = 1e-2
    batchSize = 256
    alpha = -0.5
    beta = 0.2
    epochs = 400  # alternative left by the author: 300
    clip = 1e-2
    l2 = 1e-6

elif selectModel == "CNN":
    # .to(device) does not move anything here: the CNN layers already live on `device`.
    model = CNN().to(device)

    # hyperparameters
    lr = 4e-3  # alternative left by the author: 1e-2
    batchSize = 256
    alpha = -0.5
    beta = 0.2
    epochs = 400
    clip = 1e-2
    l2 = 1e-6

else:
    # Fail here with a clear message instead of a NameError on `model` further down.
    raise ValueError("selectModel must be \"FNN\" or \"CNN\", got {!r}".format(selectModel))

print("number of parameters: {:d}".format(sum(p.numel() for p in model.parameters())))

**define dataloader**

In [None]:
# Training data is served in mini-batches of `batchSize` samples.
# NOTE(review): no shuffle=True is set, so the batch composition is the same in every epoch.
dataloaderTraining = DataLoader(dataset=datasetTraining, batch_size=batchSize)
# The whole validation set is delivered as one single batch.
dataloaderValidation = DataLoader(dataset=datasetValidation, batch_size=len(datasetValidation))

## Optimizer setup

In [None]:
# Adam with L2 regularization applied through weight decay.
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2)


def lr_lambda(epoch):
    """Return the factor (beta * epoch + 1) ** alpha that scales the initial learning rate.

    `epoch` is the counter LambdaLR passes in, i.e. the number of
    scheduler.step() calls made so far.
    """
    return (beta * epoch + 1) ** alpha


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

## Training

**training setup**

In [None]:
# One cost value per epoch for the training and the validation set.
trainingCostHistory, validationCostHistory = np.zeros(epochs), np.zeros(epochs)
# `start0` marks the beginning of the whole training run; `start` is reset
# after every progress report to time the interval between reports.
start0 = start = time.perf_counter()

**training loop**

In [None]:
def _predictCoefficients(measurements):
    """Reshape a batch of measurements for the selected architecture and run the model."""
    if selectModel == "FNN":
        # Project the measurements onto the reduced basis before the dense layers.
        return model((measurements @ SVDBasisU.t()).reshape((-1, 100)))
    if selectModel == "CNN":
        # 4 channels with settings.N[0] + 1 entries each.
        return model(measurements.reshape((-1, 4, settings.N[0] + 1)))
    raise ValueError("selectModel must be \"FNN\" or \"CNN\", got {!r}".format(selectModel))


numBatches = len(dataloaderTraining)

for epoch in range(epochs):
    # ---- training pass over all mini-batches ----
    model.train()
    for sample in dataloaderTraining:

        optimizer.zero_grad(set_to_none=True)

        coeffPred = _predictCoefficients(sample[0])

        # Half mean squared error between predicted and reference coefficients.
        cost = 0.5 * torch.mean((coeffPred - sample[2]) ** 2)

        cost.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        # NOTE(review): the scheduler is stepped once per batch, so the argument
        # of lr_lambda counts optimizer steps rather than epochs -- confirm intent.
        scheduler.step()

        # .item() stores a plain Python float, independent of device and autograd state.
        trainingCostHistory[epoch] += cost.item()

    # Average over the batches of this epoch.
    trainingCostHistory[epoch] /= numBatches

    # ---- validation pass (single batch holding the whole validation set) ----
    model.eval()
    sample = next(iter(dataloaderValidation))

    # No gradients are needed for validation; skipping them saves memory and time.
    with torch.no_grad():
        coeffPred = _predictCoefficients(sample[0])
        validationCostHistory[epoch] = (0.5 * torch.mean((coeffPred - sample[2]) ** 2)).item()

    # Progress report every 10 epochs; the timer covers the interval since the last report.
    if epoch % 10 == 0:
        elapsed_time = time.perf_counter() - start
        string = "Epoch: {}/{}\t\tTraining Cost: {:.3E}\t\tValidation Cost: {:.3E}\nElapsed time: {:.2f}"
        print(string.format(epoch, epochs - 1, trainingCostHistory[epoch], validationCostHistory[epoch], elapsed_time))
        start = time.perf_counter()

print("Total elapsed training time: {:.2f}".format(time.perf_counter() - start0))

**prediction**

In [None]:
# Predict the coefficients for the validation batch with the trained model.
model.eval()
sample = next(iter(dataloaderValidation))

# Inference only: no autograd graph is needed.
with torch.no_grad():
    if selectModel == "FNN":
        # Project the measurements onto the reduced basis before the dense layers.
        coeffPred = model((sample[0] @ SVDBasisU.t()).reshape((-1, 100)))
    elif selectModel == "CNN":
        # No .squeeze() here: the output already has one row per sample, and
        # squeezing would drop the batch axis for a single-sample batch,
        # which breaks the per-sample indexing used afterwards.
        coeffPred = model(sample[0].reshape((-1, 4, settings.N[0] + 1)))
    else:
        raise ValueError("selectModel must be \"FNN\" or \"CNN\", got {!r}".format(selectModel))

## Post-processing

**helper to transform prediction to grid**

In [None]:
def generateMaterialFromCoefficients(coeff, dataset, settings):
    """Turn three normalized coefficients into a piecewise-constant profile on the grid.

    The coefficients are denormalized with ``dataset.Coeffnorm``. The returned
    profile equals ``settings.c0[0]`` everywhere except strictly between
    ``coeff[0]`` and ``coeff[1]``, where it takes the value ``coeff[2]``.

    Args:
        coeff: tensor holding the three normalized coefficients of one sample.
        dataset: data set object providing ``Coeffnorm`` for the denormalization.
        settings: settings table providing ``Lx``, ``Nx`` and ``c0``.

    Returns:
        NumPy array with ``Nx + 1`` values, one per grid point on ``[0, Lx]``.
    """
    Lx = settings.Lx[0]
    Nx = settings.Nx[0]
    c0 = settings.c0[0]
    x = np.linspace(0, Lx, Nx + 1)

    # .cpu() is required before .numpy() when the tensor lives on the GPU;
    # for CPU tensors it is a no-op.
    coeff = DataSet.Denormalize(coeff, dataset.Coeffnorm).detach().cpu().numpy()

    # Background value everywhere, then overwrite the open interval (coeff[0], coeff[1]).
    c = np.full_like(x, c0)
    c[(x > coeff[0]) & (x < coeff[1])] = coeff[2]

    return c

**grid creation**

In [None]:
# Grid of Nx + 1 equally spaced points on [0, Lx], used as the x-axis of the plots.
Lx, Nx = settings.Lx[0], settings.Nx[0]
x = np.linspace(start=0, stop=Lx, num=Nx + 1)

**prediction visualization**

In [None]:
# 3 x 3 panel comparing prediction and ground truth for the first nine validation samples.
fig, ax = plt.subplots(3, 3, figsize=(7, 6))

for i in range(9):
    # Row and column of the panel that shows sample i.
    i_, j_ = divmod(i, 3)
    panel = ax[i_, j_]

    cpred = generateMaterialFromCoefficients(coeffPred[i], dataset, settings)
    ctrue = generateMaterialFromCoefficients(sample[2][i], dataset, settings)

    panel.plot(x, ctrue, 'k', linewidth=3)
    panel.plot(x, cpred, 'r--', linewidth=3)
    panel.set_xticks([])
    panel.set_yticks([])
    panel.set_ylim([0, settings.c0[0] * 1.1])

# Single-point dummy lines that only exist to provide the legend entries.
ax[0, 0].plot([0], [0], 'k', linewidth=3, label="ground truth")
ax[0, 0].plot([0], [0], 'r--', linewidth=3, label="prediction")

fig.tight_layout()
# Leave room above the panels for the figure-level legend.
fig.subplots_adjust(top=0.92, bottom=0.02)
fig.legend(loc='upper center', bbox_to_anchor=(0.5, 1.015), fancybox=True, ncol=2)
plt.show()

**learning history**

In [None]:
# Training and validation cost per epoch on a logarithmic y-axis.
fig, ax = plt.subplots()
for history, lineFormat, name in (
    (trainingCostHistory, 'k', "training"),
    (validationCostHistory, 'r', "validation"),
):
    ax.plot(history, lineFormat, label=name)
ax.grid()
ax.set_yscale('log')
ax.legend()
fig.tight_layout()
plt.show()