# Task 1: Feed-forward neural networks


---

Tutorial: "Machine Learning in Solid Mechanics" @ Cyber-Physical Simulation, TU Darmstadt

Lecturer: Prof. Oliver Weeger

Assistants: Dominik K. Klein, Jasper O. Schommartz

---





*Run the following cell to clone the GitHub repository into your current Google Colab environment and install the local package. For the changes to take effect, you might need to restart your Colab session before continuing ("Runtime / Restart session" in the header menu).*

In [None]:
!git clone --depth 1 https://github.com/CPShub/TutorialMLinSolidMechanics.git
!cd TutorialMLinSolidMechanics/ffnn_introduction && pip install -e .

*Run the following cell to import all modules and Python files into this notebook. If you make changes to the Python files, run this cell again to load the updated versions. You might need to restart your Colab session first ("Runtime / Restart session" in the header menu).*

In [None]:
import datetime
import importlib

import jax
import jax.random as jrandom
import klax
from matplotlib import pyplot as plt
import time

import tmlsm.losses as tl
import tmlsm.data as td
import tmlsm.models as tm

importlib.reload(tm)
importlib.reload(tl)
importlib.reload(td)

now = datetime.datetime.now

*If you want to clone the repository again, you have to delete it from your Google Colab files first. For this, you can run the following cell.*

In [None]:
%rm -rf TutorialMLinSolidMechanics

### Load data and model

In [None]:
# Create a random key for weight initialization and batch splitting.
# Seeding with `time.time_ns()` gives different results on every run;
# replace it with a constant seed (e.g. `jrandom.PRNGKey(0)`) if
# exactly reproducible results are required.
key = jrandom.PRNGKey(time.time_ns())
keys = jrandom.split(key, 2)

# Build model instance
model = tm.build(key=keys[0])

# Load data
x, y, x_cal, y_cal = td.bathtub()

print(model)
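
Replacing `time.time_ns()` with a constant seed makes every random draw repeatable, while the two subkeys from `jrandom.split` keep weight initialization and batch splitting independent of each other. The snippet below is a small self-contained check of this behaviour using only `jax.random`; the seed values 42 and 43 and the three-element "weight" vectors are arbitrary stand-ins, not part of the tutorial package.

```python
import jax.random as jrandom

# Same constant seed -> identical key -> identical random draws.
key_a = jrandom.PRNGKey(42)
key_b = jrandom.PRNGKey(42)

# Split each key into two subkeys, mirroring keys[0] (weight
# initialization) and keys[1] (batch splitting) in the cell above.
keys_a = jrandom.split(key_a, 2)
keys_b = jrandom.split(key_b, 2)

# Stand-in for "initial weights": three normal samples per subkey.
w_a = jrandom.normal(keys_a[0], (3,))
w_b = jrandom.normal(keys_b[0], (3,))
same_seed_identical = bool((w_a == w_b).all())

# The two subkeys of one split differ, so initialization and
# batching draw from different random streams.
subkeys_differ = bool((keys_a[0] != keys_a[1]).any())

# A different seed produces different draws.
w_c = jrandom.normal(jrandom.split(jrandom.PRNGKey(43), 2)[0], (3,))
different_seed_differs = bool((w_a != w_c).any())

print(same_seed_identical, subkeys_differ, different_seed_differs)  # True True True
```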

In [None]:
# Calibrate model
t1 = now()
print(t1)

model, history = klax.fit(
    model,
    (x_cal, y_cal),
    batch_size=32,
    steps=1_000,
    loss_fn=tl.MSE(),
    history=klax.HistoryCallback(log_every=1),
    key=keys[1],
)

t2 = now()
print("it took", (t2 - t1).total_seconds(), "s to calibrate the model")

history.plot()
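
`klax.fit` hides the calibration loop. A typical mini-batch loop repeats `steps` times: draw `batch_size` samples, evaluate the loss on them, and move the weights against its gradient. The sketch below writes such a loop by hand in plain JAX, for illustration only. The one-hidden-layer softplus network, plain SGD with learning rate `LR`, and the stand-in data `y = x**2` are all assumptions; this is not how `klax.fit` or `tm.build` are implemented.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom

BATCH_SIZE = 32   # mirrors batch_size in the klax.fit call
STEPS = 1_000     # mirrors steps in the klax.fit call
LR = 0.01         # assumed learning rate for plain SGD


def init_params(key, width=16):
    """Random weights for a hypothetical 1 -> width -> 1 network."""
    k1, k2 = jrandom.split(key)
    return {
        "W1": 0.5 * jrandom.normal(k1, (width, 1)),
        "b1": jnp.zeros(width),
        "W2": 0.5 * jrandom.normal(k2, (1, width)),
        "b2": jnp.zeros(1),
    }


def mlp(params, x):
    """Evaluate the network for one sample x of shape (1,); softplus hidden layer."""
    h = jax.nn.softplus(params["W1"] @ x + params["b1"])
    return params["W2"] @ h + params["b2"]


def mse(params, xb, yb):
    """Mean squared error over a batch; jax.vmap applies mlp sample by sample."""
    pred = jax.vmap(lambda xi: mlp(params, xi))(xb)
    return jnp.mean((pred - yb) ** 2)


@jax.jit
def train_step(params, key, x_all, y_all):
    """Draw one mini-batch without replacement and take one SGD step on it."""
    idx = jrandom.choice(key, x_all.shape[0], (BATCH_SIZE,), replace=False)
    loss, grads = jax.value_and_grad(mse)(params, x_all[idx], y_all[idx])
    params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return params, loss


# Hypothetical calibration data: y = x^2 on [-1, 1] (not the tutorial's bathtub data).
x_cal = jnp.linspace(-1.0, 1.0, 200).reshape(-1, 1)
y_cal = x_cal ** 2

key = jrandom.PRNGKey(0)
key, init_key = jrandom.split(key)
params = init_params(init_key)

initial_loss = float(mse(params, x_cal, y_cal))
for step in range(STEPS):
    key, batch_key = jrandom.split(key)  # fresh key for every mini-batch
    params, batch_loss = train_step(params, batch_key, x_cal, y_cal)
final_loss = float(mse(params, x_cal, y_cal))
print(f"MSE before: {initial_loss:.4f}, after: {final_loss:.4f}")
```

Keeping sampling, gradient evaluation and the weight update inside one jitted `train_step` means the Python loop only threads the PRNG key through; an adaptive optimizer such as Adam would usually converge faster than the plain SGD assumed here.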

### Model evaluation

In [None]:
# First, the model needs to be finalized to unwrap and apply all
# wrappers and constraints (if present).
model_ = klax.finalize(model)

plt.figure(2)
plt.scatter(x_cal[::10], y_cal[::10], c="green", label="calibration data")
plt.plot(x, y, c="black", linestyle="--", label="bathtub function")
plt.plot(x, jax.vmap(model_)(x), label="model", color="red")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.show()
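
The plotting call `jax.vmap(model_)(x)` presumably relies on the model mapping a single sample to a single output, with `jax.vmap` adding the batch dimension over the first axis of `x`. The sketch below shows this mechanism with a hypothetical per-sample function `toy_model` standing in for the finalized network, and checks that the vectorized result matches an explicit Python loop over the samples.

```python
import jax
import jax.numpy as jnp


def toy_model(x):
    """Hypothetical per-sample model: maps one input of shape (1,) to shape (1,)."""
    return jnp.tanh(2.0 * x) + 0.5 * x


x = jnp.linspace(-1.0, 1.0, 5).reshape(-1, 1)  # 5 samples, shape (5, 1)
y_pred = jax.vmap(toy_model)(x)                 # maps over axis 0 -> shape (5, 1)

# The same values obtained with an explicit Python loop over the samples.
y_loop = jnp.stack([toy_model(xi) for xi in x])

print(y_pred.shape)                          # (5, 1)
print(bool(jnp.allclose(y_pred, y_loop)))    # True
```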