# WI4450 - Special Topics in Computational Science and Engineering
# Alexander Heinlein
#
# Decomposing Neural Networks

## Motivation

In this lecture, we will discuss approaches for **decomposing neural networks**. There are different reasons for doing so; we will focus on:

+ **Parallelization and efficiency** <br> and
+ **Robustness and convergence**

Moreover, there are approaches which take a **given neural network and decompose the application of that given network** and approaches which **develop a decomposed neural network architecture** itself. We will briefly discuss the first approach and then mostly focus on the second approach and link it to approaches from numerical analysis and scientific computing.

## Parallelization Concepts for Neural Networks

First, let us consider some of the parallelization concepts discussed in:

> Fedus, William, Barret Zoph, and Noam Shazeer. "Switch transformers: Scaling to trillion parameter models with simple and efficient sparsity." Journal of Machine Learning Research 23.120 (2022): 1-39.



The following graphic is taken from that article and shows the **different concepts for parallelization of neural networks**:

<img src="fig/parallelization.png" width="600">


We can see several data and weight partitioning strategies, the first three of which we will discuss:

+ **Data parallelism**
+ **Model parallelism**
+ **Model and data parallelism**

### Data Parallelism

In the data parallelism approach, the **data set is split into different parts** and the same model is applied to each part. This is the **most common and simplest approach** for parallelizing neural networks. 
+ In **distributed-memory contexts**, this can be implemented using different devices; this goes beyond the scope of this lecture. 
+ In **shared-memory contexts**, this can be implemented using threads. In Jax, the `vmap` function vectorizes a function over the data samples, which makes this easy to implement. 

This means that we do not parallelize the computation of a single application of the model but parallelize the sum in the **mean squared error (MSE) loss function**
$$
    \min_{\theta} \underbrace{\frac{1}{n} \sum_{i=1}^{n} \left( \mathcal{N}_\theta (x_i,t_i) - y_i \right)^2}_{=: \mathcal{L}(\theta)}.
$$

In particular, we split the $n$ indices into $P$ disjoint parts $I_1, \ldots, I_P$ and compute the sum as
$$
    \mathcal{L}(\theta) = \sum_{p=1}^{P} \underbrace{\frac{1}{n} \sum_{i \in I_p} \left( \mathcal{N}_\theta (x_i,t_i) - y_i \right)^2}_{=: \mathcal{L}_p(\theta)},
$$
where each partial sum
$$
    \sum_{i \in I_p} \left( \mathcal{N}_\theta (x_i,t_i) - y_i \right)^2
$$
can be computed completely in parallel.

We only need to **synchronize/communicate** to compute the final sum.
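The splitting of the loss can be checked numerically with a small sketch. The toy `model` below is a hypothetical stand-in for $\mathcal{N}_\theta$ (it is not the network used later), and the number of parts `P = 3` is an arbitrary choice; each partial loss is assumed to carry the factor $1/n$.

```python
import jax.numpy as jnp

# Hypothetical toy model standing in for the network N_theta
def model(theta, x):
    return theta[0] * jnp.tanh(theta[1] * x)

theta = jnp.array([0.5, 2.0])
x = jnp.linspace(-1.0, 1.0, 12)
y = jnp.sin(jnp.pi * x)
n = x.shape[0]

# Full mean squared error over all n samples
full_loss = jnp.mean((model(theta, x) - y) ** 2)

# Split the n indices into P index sets I_1, ..., I_P
P = 3
index_sets = jnp.array_split(jnp.arange(n), P)

# Each partial loss only touches its own samples and could run in parallel
partial_losses = [jnp.sum((model(theta, x[I]) - y[I]) ** 2) / n for I in index_sets]

# The final sum is the only step that needs synchronization
split_loss = sum(partial_losses)
print(full_loss, split_loss)
```

Both values should agree up to floating-point rounding.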

Then, in each step of **gradient descent**,
$$
    \theta_{k+1} = \theta_k - \alpha \nabla \mathcal{L}(\theta_k),
$$
we also have to compute the gradient of the loss function $\mathcal{L}(\theta)$ in parallel, which is simple due to the linearity of the gradient:
$$
    \nabla \mathcal{L}(\theta) = \sum_{p=1}^{P} \nabla \mathcal{L}_p(\theta).
$$

Then, the gradients computed in parallel have to be summed up to update the parameters $\theta$. In shared-memory contexts, this again requires synchronization, but in distributed-memory contexts, this requires the **communication of the gradients between the different devices**.
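The linearity of the gradient can be illustrated in the same way. This is a minimal sketch, assuming a hypothetical toy `model` and a helper `partial_loss` (neither appears in the lecture code) in which each part is scaled by $1/n$ so that the parts sum to the full MSE loss.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy model standing in for the network N_theta
def model(theta, x):
    return theta[0] * jnp.tanh(theta[1] * x)

def partial_loss(theta, x, y, n):
    # Contribution L_p of one index set, scaled by 1/n
    return jnp.sum((model(theta, x) - y) ** 2) / n

theta = jnp.array([0.5, 2.0])
x = jnp.linspace(-1.0, 1.0, 12)
y = jnp.sin(jnp.pi * x)
n = x.shape[0]

# Gradient of the full loss (all samples in one part)
full_grad = jax.grad(partial_loss)(theta, x, y, n)

# Gradients of the P partial losses, each computable independently
P = 4
x_parts = jnp.split(x, P)
y_parts = jnp.split(y, P)
partial_grads = [jax.grad(partial_loss)(theta, xp, yp, n)
                 for xp, yp in zip(x_parts, y_parts)]

# Summing the partial gradients is the communication step
summed_grad = sum(partial_grads)
print(full_grad, summed_grad)
```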

We will now investigate this using a Python example using Jax.

First, we consider again the `FeedForwardNN` class from the previous lecture:

In [None]:
import jax
from jax import random, vmap, jit
import jax.numpy as jnp

class FeedForwardNN:
    def __init__(self, layer_sizes, key, activation_fn=jax.nn.tanh):
        self.layer_sizes = layer_sizes
        self.activation_fn = activation_fn
        self.params = self.initialize_params(layer_sizes, key)

    def initialize_params(self, layer_sizes, key):
        params = []
        keys = random.split(key, len(layer_sizes) - 1)
        for i in range(len(layer_sizes) - 1):
            W_key, b_key = random.split(keys[i])
            # Initialize weights from a uniform distribution on [-1, 1]
            W = random.uniform(W_key, (layer_sizes[i], layer_sizes[i+1]), minval=-1.0, maxval=1.0)
            # Initialize biases with zeros
            b = jnp.zeros(layer_sizes[i+1])
            params.append((W, b))
        return params

    def forward(self, params, x):
        for W, b in params[:-1]:
            # Linear transformation
            x = jnp.dot(x, W) + b
            # Apply activation function
            x = self.activation_fn(x)
        # Output layer (no activation function)
        W, b = params[-1]
        x = jnp.dot(x, W) + b
        return x

    def predict(self, x):
        # Predict output for input x
        return vmap(self.forward, in_axes=(None, 0))(self.params, x)

Now, we will train this network to fit a simple $\sin$ function using many data points:

In [None]:
import matplotlib.pyplot as plt

# Generate sample data
x = jnp.linspace(-1.0, 1.0, 1000)
y = jnp.sin(jnp.pi*x)

# Plot the results
plt.plot(x, y, label='Sine function')
plt.gcf().set_size_inches(4, 2) 
plt.legend()
plt.show()

Let us first implement a non-parallelized version of the network.

In [None]:
import optax
from tqdm import trange

# Define the neural network
key = random.PRNGKey(0)
layer_sizes = [1, 20, 1] 
nn = FeedForwardNN(layer_sizes, key)

# Create a learning rate schedule
learning_rate_schedule = optax.piecewise_constant_schedule(
    init_value=0.1,
)

# Initialize the Adam optimizer with the learning rate schedule
optimizer = optax.adam(learning_rate=learning_rate_schedule)
opt_state = optimizer.init(nn.params)

# Define the loss function (Mean Squared Error)
squared_error = lambda params, x, y: (nn.forward(params, x) - y) ** 2
def loss_fn(params, x, y):
    squared_errors = jnp.array([squared_error(params, xi, yi) for xi, yi in zip(x, y)])
    return jnp.mean(squared_errors)

# Define training step with Adam optimizer
def train_step(params, opt_state, x, y):    
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss

Now, we train the network:

In [None]:
# Training loop
losses = []
max_iterations = 10

# Training loop with tqdm progress bar
pbar = trange(max_iterations, desc="Training", leave=True)
for epoch in pbar:
    nn.params, opt_state, current_loss = train_step(nn.params, opt_state, x, y)
    losses.append(current_loss)
    pbar.set_postfix(loss=current_loss)

This is very slow. Finally, we plot the result:

In [None]:
# Create a figure with two subplots
fig, axs = plt.subplots(1, 2, figsize=(10, 2))

# Plot the loss over the number of iterations
axs[0].set_yscale('log')
axs[0].plot(range(max_iterations), losses, label='Loss')
axs[0].set_xlabel('Iteration')
axs[0].set_ylabel('Loss')
axs[0].legend()

# Plot the results
axs[1].plot(x, y, label='Sine function')
axs[1].plot(x, nn.predict(x).reshape(y.shape), label='Neural network')
axs[1].legend()

# Show the plots
plt.show()

The iteration is very slow because each data sample is processed **sequentially in a for loop**:

```python
squared_errors = jnp.array([squared_error(params, xi, yi) for xi, yi in zip(x, y)])
```

Now, we will use **data parallelism** to speed up the training. We now **parallelize over the data samples** using the `vmap` function from Jax:

```python
squared_errors = vmap(squared_error, in_axes=(None, 0, 0))(params, x, y)
```

In [None]:
import optax
from jax import vmap
from tqdm import trange

# Define the neural network
key = random.PRNGKey(0)
layer_sizes = [1, 20, 1] 
nn = FeedForwardNN(layer_sizes, key)

# Create a learning rate schedule
learning_rate_schedule = optax.piecewise_constant_schedule(
    init_value=0.1
)

# Initialize the Adam optimizer with the learning rate schedule
optimizer = optax.adam(learning_rate=learning_rate_schedule)
opt_state = optimizer.init(nn.params)

# Define the loss function (Mean Squared Error)
squared_error = lambda params, x, y: (nn.forward(params, x) - y) ** 2
def loss_fn(params, x, y):
    squared_errors = vmap(squared_error, in_axes=(None, 0, 0))(params, x, y)
    return jnp.mean(squared_errors)

# Define training step with Adam optimizer
def train_step(params, opt_state, x, y):    
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss

We again train for the same number of epochs:

In [None]:
# Training loop
losses = []
max_iterations = 10

# Training loop with tqdm progress bar
pbar = trange(max_iterations, desc="Training", leave=True)
for epoch in pbar:
    nn.params, opt_state, current_loss = train_step(nn.params, opt_state, x, y)
    losses.append(current_loss)
    if epoch % 100 == 0:
        pbar.set_postfix(loss=current_loss)

And we plot to obtain the same result:

In [None]:
# Create a figure with two subplots
fig, axs = plt.subplots(1, 2, figsize=(10, 2))

# Plot the loss over the number of iterations
axs[0].set_yscale('log')
axs[0].plot(range(max_iterations), losses, label='Loss')
axs[0].set_xlabel('Iteration')
axs[0].set_ylabel('Loss')
axs[0].legend()

# Plot the results
axs[1].plot(x, y, label='Sine function')
axs[1].plot(x, nn.predict(x).reshape(y.shape), label='Neural network')
axs[1].legend()

# Show the plots
plt.show()

We have obtained the same result in a fraction of the time. We can easily **run the same training for many more epochs in a short time**:

In [None]:
import optax
from jax import vmap
from tqdm import trange

# Define the neural network
key = random.PRNGKey(0)
layer_sizes = [1, 20, 1] 
nn = FeedForwardNN(layer_sizes, key)

# Create a learning rate schedule
learning_rate_schedule = optax.piecewise_constant_schedule(
    init_value=0.1,
    boundaries_and_scales={300: 0.5, 600: 0.5}
)

# Initialize the Adam optimizer with the learning rate schedule
optimizer = optax.adam(learning_rate=learning_rate_schedule)
opt_state = optimizer.init(nn.params)

# Define the loss function (Mean Squared Error)
squared_error = jit(lambda params, x, y: (nn.forward(params, x) - y) ** 2)
@jit
def loss_fn(params, x, y):
    squared_errors = vmap(squared_error, in_axes=(None, 0, 0))(params, x, y)
    return jnp.mean(squared_errors)

# Define training step with Adam optimizer
def train_step(params, opt_state, x, y):    
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss

# Training loop
losses = []
max_iterations = 1000

# Training loop with tqdm progress bar
pbar = trange(max_iterations, desc="Training", leave=True)
for epoch in pbar:
    nn.params, opt_state, current_loss = train_step(nn.params, opt_state, x, y)
    losses.append(current_loss)
    if epoch % 100 == 0:
        pbar.set_postfix(loss=current_loss)

Now, we obtain a much better fit:

In [None]:
# Create a figure with two subplots
fig, axs = plt.subplots(1, 2, figsize=(10, 2))

# Plot the loss over the number of iterations
axs[0].set_yscale('log')
axs[0].plot(range(max_iterations), losses, label='Loss')
axs[0].set_xlabel('Iteration')
axs[0].set_ylabel('Loss')
axs[0].legend()

# Plot the results
axs[1].plot(x, y, label='Sine function')
axs[1].plot(x, nn.predict(x).reshape(y.shape), label='Neural network')
axs[1].legend()

# Show the plots
plt.show()

### Model Parallelism

In the model parallelism approach, the **model itself is split into different parts** and each part is applied to the data. This can be useful if the **model is too large to fit on a single device**. The main disadvantage is that the model has to be split into different parts, which can lead to **communication overhead in the case of distributed-memory parallelism**.

For a simple feedforward neural network, parallelization of the model means **splitting up the weights and biases of the network** and then **parallelizing the application of the neural network**. To discuss the concept, let us first consider the case of a neural network with a single hidden layer; the case of multiple hidden layers can be discussed analogously.

Consider the neural network
$$
    \mathcal{N}_\theta (x) = A^\top \sigma \big( W x + b \big),
$$
where $W \in \mathbb{R}^{m \times d}$, $b \in \mathbb{R}^m$, and $A \in \mathbb{R}^m$. We can split the weights and biases row-wise into $P$ parts $W_1, \ldots, W_P$, $b_1, \ldots, b_P$, and $A_1, \ldots, A_P$ as follows
$$
    W
    =
    \begin{bmatrix}
        W_1 \\
        \vdots \\
        W_P
    \end{bmatrix},
    \quad
    b
    =
    \begin{bmatrix}
        b_1 \\
        \vdots \\
        b_P
    \end{bmatrix},
    \quad \text{and} \quad
    A
    =
    \begin{bmatrix}
        A_1 \\
        \vdots \\
        A_P
    \end{bmatrix}.
$$


Then, we can write the application of the neural network as follows
$$
    \mathcal{N}_\theta (x)
    =
    \begin{bmatrix}
        A_1^\top & \cdots & A_P^\top
    \end{bmatrix}
    \cdot
    \begin{bmatrix}
        \sigma \big( W_1 x + b_1 \big) \\
        \vdots \\
        \sigma \big( W_P x + b_P \big)
    \end{bmatrix}
    =
    \sum_{p=1}^{P} \underbrace{A_p^\top \sigma \big( W_p x + b_p \big)}_{=: \mathcal{N}_{p,\theta_p} (x)}.
$$

**Each part of the sum can be computed independently**: the application of a network with a single hidden layer can be written as the sum of individual models $\mathcal{N}_{p,\theta_p} (x)$. This can be implemented using the `pmap` function in Jax. We will skip this for the sake of brevity.

In distributed-memory contexts, this can be implemented using different devices, and communication between the individual devices is required in two steps:
1. We have to **communicate the input vector $x$ to all devices**.
2. We have to **communicate and sum up the outputs of the individual devices**.

In shared-memory contexts, no communication is required, but we have to synchronize the computation.
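The decomposition of the forward pass can be checked with a short NumPy sketch. It does not use `pmap` or multiple devices; it only verifies that the sum of the sub-networks reproduces the full network. The sizes `d`, `m`, `P`, the random parameters, and the choice $\sigma = \tanh$ are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
d, m, P = 3, 8, 4            # input dimension, hidden width, number of parts
W = rng.standard_normal((m, d))
b = rng.standard_normal(m)
A = rng.standard_normal(m)
x = rng.standard_normal(d)

# Full network: N(x) = A^T sigma(W x + b) with sigma = tanh
full_output = A @ np.tanh(W @ x + b)

# Split the rows of W and the entries of b and A into P blocks
W_parts = np.split(W, P, axis=0)
b_parts = np.split(b, P)
A_parts = np.split(A, P)

# Each sub-network N_p needs only x and its own block of parameters
partial_outputs = [A_p @ np.tanh(W_p @ x + b_p)
                   for W_p, b_p, A_p in zip(W_parts, b_parts, A_parts)]

# Summing the partial outputs corresponds to the second communication step
split_output = sum(partial_outputs)
print(full_output, split_output)
```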

The gradient of the loss function $\mathcal{L}(\theta)$ also has to be computed in parallel. Due to the linearity of the gradient, we have
$$
    \nabla \mathcal{L}(\theta) 
    = 
    \nabla \frac{1}{n} \sum_{i=1}^{n} \left\| \mathcal{N}_\theta (x_i) - y_i \right\|^2
    = 
    \nabla \frac{1}{n} \sum_{i=1}^{n} \left\| \sum_{p=1}^{P} \mathcal{N}_{p,\theta_p} (x_i) - y_i \right\|^2
    = 
    \frac{1}{n} \sum_{i=1}^{n} \nabla \left\| \sum_{p=1}^{P} \mathcal{N}_{p,\theta_p} (x_i) - y_i \right\|^2
$$

Looking at the parameters $\theta_p$ of one part and only one data point $(x_i, y_i)$ of the sum, we write $L_i(\theta)$ for the corresponding term and obtain
$$
    \nabla_{\theta_p} L_i(\theta)
    =
    \nabla_{\theta_p} \left\| \sum_{q=1}^{P} \mathcal{N}_{q,\theta_q} (x_i) - y_i \right\|^2
    =
    \nabla_{\theta_p} \left( \sum_{q=1}^{P} \mathcal{N}_{q,\theta_q} (x_i) - y_i \right)^\top \left( \sum_{q=1}^{P} \mathcal{N}_{q,\theta_q} (x_i) - y_i \right)
$$
and, after expanding the product,
$$
    \begin{aligned}
        \nabla_{\theta_p} L_i(\theta) 
        & =
        \nabla_{\theta_p} \left(
        \sum_{q,r} \left( \mathcal{N}_{q,\theta_q} (x_i) \right)^\top \mathcal{N}_{r,\theta_r} (x_i) - 2 \sum_q \left( \mathcal{N}_{q,\theta_q} (x_i) \right)^\top y_i + y_i^\top y_i \right) \\
        & =
        2 \left( \nabla_{\theta_p} \mathcal{N}_{p,\theta_p} (x_i) \right)^\top \left( \sum_{q=1}^{P} \mathcal{N}_{q,\theta_q} (x_i) - y_i \right),
    \end{aligned}
$$
where the last step uses that only $\mathcal{N}_{p,\theta_p}$ depends on $\theta_p$. We observe that **once $\mathcal{N}_{q,\theta_q} (x_i)$ has been computed for each $q$ and the outputs have been communicated to all devices**, each device can compute $\nabla_{\theta_p} \mathcal{N}_{p,\theta_p} (x_i)$ and thus the gradient with respect to its own parameters $\theta_p$ in parallel.

This means that the **communication required can be considerable**, but computations can be done in parallel.
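The structure of this gradient can be illustrated with automatic differentiation. The sketch below assumes scalar-valued sub-networks with $\sigma = \tanh$; the helper names `sub_network` and `loss` and all sizes are hypothetical. For one data point, the chain rule gives $\nabla_{\theta_p} = 2 \, r \, \nabla_{\theta_p} \mathcal{N}_{p,\theta_p}(x)$ with the residual $r = \sum_q \mathcal{N}_{q,\theta_q}(x) - y$, so only $r$ requires the outputs of all parts.

```python
import jax
import jax.numpy as jnp

def sub_network(theta_p, x):
    # N_p(x) = A_p^T tanh(W_p x + b_p), scalar output
    W_p, b_p, A_p = theta_p
    return A_p @ jnp.tanh(W_p @ x + b_p)

def loss(thetas, x, y):
    # Squared error for one data point; the network is the sum of its parts
    return (sum(sub_network(t, x) for t in thetas) - y) ** 2

key = jax.random.PRNGKey(0)
d, m_p, P = 2, 3, 4          # input dimension, neurons per part, number of parts
thetas = []
for p in range(P):
    key, k_W, k_b, k_A = jax.random.split(key, 4)
    thetas.append((jax.random.normal(k_W, (m_p, d)),
                   jax.random.normal(k_b, (m_p,)),
                   jax.random.normal(k_A, (m_p,))))
x = jnp.array([0.3, -0.7])
y = 0.5

# Reference: differentiate the full loss with respect to all parts at once
reference = jax.grad(loss)(thetas, x, y)

# Quantity that needs communication: the residual uses every partial output
residual = sum(sub_network(t, x) for t in thetas) - y

# Local work: each part differentiates only its own sub-network
local = [jax.tree_util.tree_map(lambda g: 2.0 * residual * g,
                                jax.grad(sub_network)(t, x))
         for t in thetas]
```

Each entry of `local` should match the corresponding entry of `reference` up to rounding.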

Before we move on to domain decomposition methods, we only briefly mention the **"model and data parallelization"** concept.

### Model and Data Parallelism

From the previous sections, we could see that 
+ **data parallelism** is particularly useful if the **number of data samples is very large**,
+ **model parallelism** is particularly useful if the **model is very large**.

In the model and data parallelism approach, both the model and the data are split into different parts. Then, the **aforementioned concepts can be combined for more parallelism**. This can be useful if **both the model and the data are very large**, but a really efficient parallel implementation then might require more thought.

## Domain Decomposition Methods

We will now discuss how **domain decomposition approaches** can be employed to **decompose the neural network architecture** itself. This can be useful both 
+ in the context of **function approximation** as well as
+ in the context of **solving differential equations using physics-informed neural networks**.

We will motivate the idea based on domain decomposition methods for partial differential equations and then discuss how this can be applied to neural networks.

### The Alternating Schwarz Method

Historical remarks: The alternating Schwarz method is the earliest domain decomposition method (DDM); it was invented by H. A. Schwarz and published in 1870:
+ Schwarz used the algorithm to **establish the existence of harmonic functions** with prescribed boundary values on regions with **nonsmooth boundaries**.
+ The regions are constructed recursively by forming unions of pairs of regions starting with "simple" regions for which existence can be established by more elementary means.
+ At the **core of Schwarz's work** is a proof that solving in an **alternating way on the simple subregions** yields an **iterative scheme which converges at a geometric rate**.

<img src="fig/doorknob.png" width="400">

In particular, the alternating Schwarz method can be used to solve the Poisson equation
$$
    \begin{aligned}
        - \Delta u & = f \quad \text{in} \quad \Omega, \\
        u & = 0 \quad \text{on} \quad \partial \Omega,
    \end{aligned}
$$

<img src="fig/doorknob.png" width="400">

Here, the domain $\Omega$ is decomposed into two overlapping subdomains $\Omega_1$ and $\Omega_2$. In the alternating Schwarz method, we **solve a series of local problems** on the subdomains $\Omega_1$ and $\Omega_2$ and **update the solution on the overlapping region** $\Omega_1 \cap \Omega_2$ in each iteration.

In particular, in each iteration of the method, we solve the following local problems:
$$
    \begin{array}{ccc}
        \begin{array}{rcl}
            - \Delta u_1^{(k)} & = & f \quad \text{in} \quad \Omega_1, \\
            u_1^{(k)} & = & 0 \quad \text{on} \quad \partial \Omega \cap \partial \Omega_1, \\
            u_1^{(k)} & = & u_2^{(k-1)} \quad \text{on} \quad \partial \Omega_1 \setminus \partial \Omega, \\
        \end{array}
        &
        \qquad \text{and} \qquad 
        &
        \begin{array}{rcl}
            - \Delta u_2^{(k)} & = & f \quad \text{in} \quad \Omega_2, \\
            u_2^{(k)} & = & 0 \quad \text{on} \quad \partial \Omega \cap \partial \Omega_2, \\
            u_2^{(k)} & = & u_1^{(k)} \quad \text{on} \quad \partial \Omega_2 \setminus \partial \Omega, \\
        \end{array}
    \end{array}
$$

This means that 
+ the solution of $u_1$ in the $k$-th iteration depends on $u_2$ in the $(k-1)$-th iteration and
+ the solution of $u_2$ in the $k$-th iteration depends on $u_1$ in the $k$-th iteration.

This is the reason why the method is called **alternating Schwarz method**.

### One-Dimensional Example of the Alternating Schwarz Method in Python

In order to better understand the concept, we will consider the following simple **one-dimensional example**. Let $\Omega = (0,1)$ and consider the Poisson equation on that domain:
$$
    \begin{aligned}
        - u'' & = 1 \quad \text{in} \quad \Omega, \\
        u & = 0 \quad \text{on} \quad \partial \Omega.
    \end{aligned}
$$

To apply the alternating Schwarz method, we decompose $\Omega$ into $\Omega_1 = (0,1/2+\delta)$ and $\Omega_2 = (1/2-\delta,1)$.

For the one-dimensional example under consideration, the **solutions of the local problems can be computed analytically**. In particular, for subdomain $\Omega_1$, the solution of
$$
    \begin{aligned}
        - {u_1}'' & = 1 \quad \text{in} \quad \Omega_1, \\
        u_1 (0) & = 0, \\
        u_1 (1/2+\delta) & = a, 
    \end{aligned}
$$
can be written as
$$
    u_1 (x) = u_{pde} (x) + u_{bc} (x),
$$
where $u_{bc} (x)$ is **linear and corrects the boundary values** and $u_{pde} (x)$ is a particular solution of the Poisson equation. 


In particular, we have
$$
    u_{pde} (x) = \frac{1}{2} x (1-x) 
    \quad \Rightarrow \quad
    u_{pde} (0) = 0 
    \quad \text{and} \quad
    - u_{pde}'' (x) = 1.
$$
Then, 
$$
    \begin{aligned}
        && u_{bc} (x) & = \frac{a - u_{pde}(1/2+\delta)}{1/2+\delta} x , \\
        \Rightarrow && {u_{bc}}'' & = 0, \quad
        u_{bc} (0) = 0, 
        \quad \text{and} \quad
        u_{bc} (1/2+\delta) & = a - u_{pde}(1/2+\delta), 
    \end{aligned}
$$ 

As a result, we obtain for $u_1 (x) = u_{pde} (x) + u_{bc} (x)$ that 
$$
    -{u_1}'' = 1,
    \quad
    u_1 (0) = 0,
    \quad \text{and} \quad
    u_1 (1/2+\delta) = a.
$$

Similarly, we can obtain the **solution of the second local problem** 
$$
    \begin{aligned}
        - {u_2}'' & = 1 \quad \text{in} \quad \Omega_2, \\
        u_2 (1/2-\delta) & = b, \\
        u_2 (1) & = 0, 
    \end{aligned}
$$
on subdomain $\Omega_2$:
$$
    u_2 (x) = \frac{1}{2} x (1-x) + \frac{b - u_{pde}(1/2-\delta)}{1/2+\delta} (1 - x).
$$
Note that $u_{pde}(1/2-\delta) = u_{pde}(1/2+\delta)$ because $u_{pde}$ is symmetric with respect to $x = 1/2$.
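The two analytical local solutions can be sanity-checked numerically. The following sketch assumes that $u_1$ is prescribed at $x = 0$ and $x = 1/2+\delta$ and $u_2$ at $x = 1/2-\delta$ and $x = 1$; the values of `delta`, `a`, `b` and the helper `neg_second_derivative` are illustrative choices.

```python
def u_pde(x):
    return 0.5 * x * (1 - x)

def u_1(x, delta, a):
    return u_pde(x) + (a - u_pde(0.5 + delta)) / (0.5 + delta) * x

def u_2(x, delta, b):
    # u_pde(0.5 - delta) equals u_pde(0.5 + delta) by symmetry around x = 1/2
    return u_pde(x) + (b - u_pde(0.5 - delta)) / (0.5 + delta) * (1 - x)

delta, a, b = 0.1, 0.05, 0.07   # illustrative values

# Boundary values of both local solutions
bc_1 = (u_1(0.0, delta, a), u_1(0.5 + delta, delta, a))
bc_2 = (u_2(0.5 - delta, delta, b), u_2(1.0, delta, b))

def neg_second_derivative(u, x, h=1e-3):
    # Central finite difference; exact for quadratics up to rounding
    return -(u(x - h) - 2 * u(x) + u(x + h)) / h**2

# Evaluate -u'' at one interior point of each subdomain
res_1 = neg_second_derivative(lambda x: u_1(x, delta, a), 0.3)
res_2 = neg_second_derivative(lambda x: u_2(x, delta, b), 0.7)
print(bc_1, bc_2, res_1, res_2)
```

The boundary values should reproduce `(0, a)` and `(b, 0)`, and both residuals should be close to the right-hand side $1$.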

Now, we will consider an example in Python. We start with the **initial guess**
$$
    u^{(0)} = 0
$$
and
$$
    u_1^{(0)} = u^{(0)} = 0
    \quad \text{and} \quad
    u_2^{(0)} = u^{(0)} = 0.
$$

In [None]:
import numpy as np

def u_pde(x):
    return 0.5*x*(1-x)

def u_1(x, delta, a):
    return u_pde(x) + (a-u_pde(0.5+delta))/(0.5+delta)*x

def u_2(x, delta, b):
    return u_pde(x) + (b-u_pde(0.5+delta))/(0.5+delta)*(1-x)

def alternating_schwarz(delta, n):
    a = 0
    b = u_1(0.5-delta, delta, a)
    for i in range(n-1):
        a = u_2(0.5+delta, delta, b)
        b = u_1(0.5-delta, delta, a)

    X_1 = np.linspace(0, 0.5+delta, 100)
    U_1 = u_1(X_1, delta, a)
    X_2 = np.linspace(0.5-delta, 1, 100)
    U_2 = u_2(X_2, delta, b)
    
    return X_1, U_1, X_2, U_2

Now, we compute and plot the iterates:

In [None]:
%matplotlib ipympl
from ipywidgets import interact
import matplotlib.pyplot as plt

delta = 0.1

x1 = np.linspace(0, 0.5+delta, 100)
x2 = np.linspace(0.5-delta, 1, 100)
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.set_xlim(0., 1)
ax.set_ylim(0., 0.14)
line1a, = ax.plot(x1, 0.*x1, 'r--')
line1b, = ax.plot(x1, 0.*x1, 'r')
line2a, = ax.plot(x2, 0.*x2, 'b--')
line2b, = ax.plot(x2, 0.*x2, 'b')

def update(delta = 0.1, n = 0):
    if n == 0:
        x1 = np.linspace(0, 0.5+delta, 100)
        x2 = np.linspace(0.5-delta, 1, 100)
        line1a.set_xdata(x1)
        line1a.set_ydata(0.*x1)
        line1b.set_xdata(x1)
        line1b.set_ydata(0.*x1)
        line2a.set_xdata(x2)
        line2a.set_ydata(0.*x2)
        line2b.set_xdata(x2)
        line2b.set_ydata(0.*x2)
    else:
        X_1a, U_1a, X_2a, U_2a = alternating_schwarz(delta, n-1)
        X_1b, U_1b, X_2b, U_2b = alternating_schwarz(delta, n)
        if n == 1:
            U_1a = 0.*U_1a
            U_2a = 0.*U_2a
        line1a.set_xdata(X_1a)
        line1a.set_ydata(U_1a)
        line1b.set_xdata(X_1b)
        line1b.set_ydata(U_1b)
        line2a.set_xdata(X_2a)
        line2a.set_ydata(U_2a)
        line2b.set_xdata(X_2b)
        line2b.set_ydata(U_2b)

    fig.canvas.draw_idle()

interact(update, delta=(0.,0.5,0.1), n=(0,20,1));

We observe that the **solution converges to the exact solution of the Poisson equation**. Moreover, a **larger overlap** of the subdomains leads to **faster convergence**.
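The influence of the overlap can also be quantified. For this one-dimensional example, the error is linear on each subdomain, and a short calculation suggests that the error of the interface value at $x = 1/2+\delta$ is reduced by the factor $\left( \frac{1/2-\delta}{1/2+\delta} \right)^2$ in each alternating Schwarz iteration. The sketch below compares this prediction with the observed reduction; the helper `interface_errors` and the chosen values of `delta` are assumptions for illustration.

```python
def u_pde(x):
    return 0.5 * x * (1 - x)

def u_1(x, delta, a):
    return u_pde(x) + (a - u_pde(0.5 + delta)) / (0.5 + delta) * x

def u_2(x, delta, b):
    return u_pde(x) + (b - u_pde(0.5 + delta)) / (0.5 + delta) * (1 - x)

def interface_errors(delta, n):
    # Error of the interface value a at x = 1/2 + delta over n iterations
    a_exact = u_pde(0.5 + delta)
    a = 0.0
    errors = [abs(a - a_exact)]
    for _ in range(n):
        b = u_1(0.5 - delta, delta, a)   # trace of u_1 on the boundary of Omega_2
        a = u_2(0.5 + delta, delta, b)   # trace of u_2 on the boundary of Omega_1
        errors.append(abs(a - a_exact))
    return errors

results = {}
for delta in (0.05, 0.1, 0.2):
    errors = interface_errors(delta, 5)
    observed = errors[1] / errors[0]
    predicted = ((0.5 - delta) / (0.5 + delta)) ** 2
    results[delta] = (observed, predicted)
    print(f"delta={delta}: observed={observed:.6f}, predicted={predicted:.6f}")
```

A larger `delta` gives a smaller factor, which is consistent with the faster convergence seen in the interactive plot.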

Moreover, the algorithm is sequential with respect to the solution of the local problems on the subdomains. This is because the solution of the first subdomain problem requires the solution of the other subdomain problem from the previous iteration:
$$
    \begin{aligned}
        - {u_1^{(k)}}'' & = 1 \quad \text{in} \quad \Omega_1, \\
        u_1^{(k)} & = 0 \quad \text{on} \quad \partial \Omega \cap \partial \Omega_1, \\
        u_1^{(k)} & = u_2^{(k-1)} \quad \text{on} \quad \partial \Omega_1 \setminus \partial \Omega, \\
    \end{aligned}
$$
whereas the solution of the second subdomain problem requires the solution of the first subdomain problem from the current iteration:
$$
    \begin{aligned}
        - {u_2^{(k)}}'' & = 1 \quad \text{in} \quad \Omega_2, \\
        u_2^{(k)} & = 0 \quad \text{on} \quad \partial \Omega \cap \partial \Omega_2, \\
        u_2^{(k)} & = u_1^{(k)} \quad \text{on} \quad \partial \Omega_2 \setminus \partial \Omega. \\
    \end{aligned}
$$

Can it be parallelized? How do we have to modify the algorithm?

### The Parallel Schwarz Algorithm

Next, we extend this idea to the parallel Schwarz algorithm. In the parallel Schwarz algorithm, we solve the local problems on the subdomains in parallel, which only requires a minor modification of the algorithm. In particular, in each iteration of the method, we solve the following local problems:
$$
    \begin{aligned}
        - {u_1^{(k)}}'' & = 1 \quad \text{in} \quad \Omega_1, \\
        u_1^{(k)} & = 0 \quad \text{on} \quad \partial \Omega \cap \partial \Omega_1, \\
        u_1^{(k)} & = u_2^{(k-1)} \quad \text{on} \quad \partial \Omega_1 \setminus \partial \Omega, \\
    \end{aligned}
$$
and
$$
    \begin{aligned}
        - {u_2^{(k)}}'' & = 1 \quad \text{in} \quad \Omega_2, \\
        u_2^{(k)} & = 0 \quad \text{on} \quad \partial \Omega \cap \partial \Omega_2, \\
        u_2^{(k)} & = u_1^{(k-1)} \quad \text{on} \quad \partial \Omega_2 \setminus \partial \Omega. \\
    \end{aligned}
$$

The two problems can now be solved completely in parallel because, within one iteration, neither subdomain problem depends on the solution of the other subdomain problem from the current iteration; both only use data from the previous iteration.

### One-Dimensional Example of the Parallel Schwarz Method in Python

Based on the previously derived analytical solutions, we can now directly **implement the parallel Schwarz algorithm**.

In [None]:
import numpy as np

def u_pde(x):
    return 0.5*x*(1-x)

def u_1(x, delta, a):
    return u_pde(x) + (a-u_pde(0.5+delta))/(0.5+delta)*x

def u_2(x, delta, b):
    return u_pde(x) + (b-u_pde(0.5+delta))/(0.5+delta)*(1-x)

def parallel_schwarz(delta, n):
    a_prev = 0
    b_prev = 0
    for i in range(n-1):
        b = u_1(0.5-delta, delta, a_prev)
        a = u_2(0.5+delta, delta, b_prev)
        a_prev = a
        b_prev = b

    X_1 = np.linspace(0, 0.5+delta, 100)
    U_1 = u_1(X_1, delta, a_prev)
    X_2 = np.linspace(0.5-delta, 1, 100)
    U_2 = u_2(X_2, delta, b_prev)
    
    return X_1, U_1, X_2, U_2

Now, we compute and plot the iterates:

In [None]:
%matplotlib ipympl
from ipywidgets import interact
import matplotlib.pyplot as plt

delta = 0.1

x1 = np.linspace(0, 0.5+delta, 100)
x2 = np.linspace(0.5-delta, 1, 100)
fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)
ax.set_xlim(0., 1)
ax.set_ylim(0., 0.14)
line1a, = ax.plot(x1, 0.*x1, 'r--')
line1b, = ax.plot(x1, 0.*x1, 'r')
line2a, = ax.plot(x2, 0.*x2, 'b--')
line2b, = ax.plot(x2, 0.*x2, 'b')

def update(delta = 0.1, n = 0):
    if n == 0:
        x1 = np.linspace(0, 0.5+delta, 100)
        x2 = np.linspace(0.5-delta, 1, 100)
        line1a.set_xdata(x1)
        line1a.set_ydata(0.*x1)
        line1b.set_xdata(x1)
        line1b.set_ydata(0.*x1)
        line2a.set_xdata(x2)
        line2a.set_ydata(0.*x2)
        line2b.set_xdata(x2)
        line2b.set_ydata(0.*x2)
    else:
        X_1a, U_1a, X_2a, U_2a = parallel_schwarz(delta, n-1)
        X_1b, U_1b, X_2b, U_2b = parallel_schwarz(delta, n)
        if n == 1:
            U_1a = 0.*U_1a
            U_2a = 0.*U_2a     
        line1a.set_xdata(X_1a)
        line1a.set_ydata(U_1a)
        line1b.set_xdata(X_1b)
        line1b.set_ydata(U_1b)
        line2a.set_xdata(X_2a)
        line2a.set_ydata(U_2a)
        line2b.set_xdata(X_2b)
        line2b.set_ydata(U_2b)

    fig.canvas.draw_idle()

interact(update, delta=(0.,0.5,0.1), n=(0,20,1));

### Comparison of the Two Methods

Finally, we **compare the two methods** next to each other:

In [None]:
%matplotlib ipympl
from ipywidgets import interact
import matplotlib.pyplot as plt

delta = 0.1

x1 = np.linspace(0, 0.5+delta, 100)
x2 = np.linspace(0.5-delta, 1, 100)
fig, ax = plt.subplots(1, 2, figsize=(10, 5))

ax[0].set_xlim(0., 1)
ax[0].set_ylim(0., 0.14)
as_line1a, = ax[0].plot(x1, 0.*x1, 'r--')
as_line1b, = ax[0].plot(x1, u_1(x1, delta, 0), 'r')
as_line2a, = ax[0].plot(x2, 0.*x2, 'b--')
as_line2b, = ax[0].plot(x2, 0.*x2, 'b')

ax[1].set_xlim(0., 1)
ax[1].set_ylim(0., 0.14)
ps_line1a, = ax[1].plot(x1, 0.*x1, 'r--')
ps_line1b, = ax[1].plot(x1, u_1(x1, delta, 0), 'r')
ps_line2a, = ax[1].plot(x2, 0.*x2, 'b--')
ps_line2b, = ax[1].plot(x2, u_2(x2, delta, 0), 'b')

def update(delta = 0.1, n = 0):
    if n == 0:
        x1 = np.linspace(0, 0.5 + delta, 100)
        x2 = np.linspace(0.5-delta, 1, 100)

        as_line1a.set_xdata(x1)
        as_line1a.set_ydata(0.*x1)
        as_line1b.set_xdata(x1)
        as_line1b.set_ydata(0.*x1)
        as_line2a.set_xdata(x2)
        as_line2a.set_ydata(0.*x2)
        as_line2b.set_xdata(x2)
        as_line2b.set_ydata(0.*x2)

        ps_line1a.set_xdata(x1)
        ps_line1a.set_ydata(0.*x1)
        ps_line1b.set_xdata(x1)
        ps_line1b.set_ydata(0.*x1)
        ps_line2a.set_xdata(x2)
        ps_line2a.set_ydata(0.*x2)
        ps_line2b.set_xdata(x2)
        ps_line2b.set_ydata(0.*x2)
    else:
        as_X_1a, as_U_1a, as_X_2a, as_U_2a = alternating_schwarz(delta, n-1)
        as_X_1b, as_U_1b, as_X_2b, as_U_2b = alternating_schwarz(delta, n)
        if n == 1:
            as_U_1a = 0.*as_U_1a
            as_U_2a = 0.*as_U_2a
        as_line1a.set_xdata(as_X_1a)
        as_line1a.set_ydata(as_U_1a)
        as_line1b.set_xdata(as_X_1b)
        as_line1b.set_ydata(as_U_1b)
        as_line2a.set_xdata(as_X_2a)
        as_line2a.set_ydata(as_U_2a)
        as_line2b.set_xdata(as_X_2b)
        as_line2b.set_ydata(as_U_2b)

        ps_X_1a, ps_U_1a, ps_X_2a, ps_U_2a = parallel_schwarz(delta, n-1)
        ps_X_1b, ps_U_1b, ps_X_2b, ps_U_2b = parallel_schwarz(delta, n)
        if n == 1:
            ps_U_1a = 0.*ps_U_1a
            ps_U_2a = 0.*ps_U_2a
        ps_line1a.set_xdata(ps_X_1a)
        ps_line1a.set_ydata(ps_U_1a)
        ps_line1b.set_xdata(ps_X_1b)
        ps_line1b.set_ydata(ps_U_1b)
        ps_line2a.set_xdata(ps_X_2a)
        ps_line2a.set_ydata(ps_U_2a)
        ps_line2b.set_xdata(ps_X_2b)
        ps_line2b.set_ydata(ps_U_2b)

    fig.canvas.draw_idle()

interact(update, delta=(0.,0.5,0.1), n=(0,20,1));

We clearly observe that 
+ the alternating Schwarz method **converges much faster in terms of the number of iterations**, 
+ whereas in the parallel Schwarz method, both subdomain solves can be carried out in parallel. 

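In this example, we observe that the **alternating Schwarz method converges exactly twice as fast as the parallel Schwarz method** in terms of the number of iterations. However, each alternating iteration consists of two sequential subdomain solves, whereas the two solves of a parallel iteration can be carried out simultaneously. Therefore, if each of the subdomain problems is solved on a separate device, both methods will converge in the same time, if we assume that the communication needed for both methods is comparable.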
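The factor of two can be checked directly on the interface values. The sketch below is a simplified variant of the iterations above that tracks only the value `a` at $x = 1/2+\delta$; the helper `interface_error` and the parameters are assumptions for illustration. Under these assumptions, the error after $k$ alternating iterations should coincide with the error after $2k$ parallel iterations.

```python
def u_pde(x):
    return 0.5 * x * (1 - x)

def u_1(x, delta, a):
    return u_pde(x) + (a - u_pde(0.5 + delta)) / (0.5 + delta) * x

def u_2(x, delta, b):
    return u_pde(x) + (b - u_pde(0.5 + delta)) / (0.5 + delta) * (1 - x)

def interface_error(delta, n, parallel):
    # Error of the interface value a at x = 1/2 + delta after n iterations
    a_exact = u_pde(0.5 + delta)
    a, b = 0.0, 0.0   # interface values of the zero initial guess
    for _ in range(n):
        if parallel:
            # Both updates use interface data from the previous iteration
            a, b = u_2(0.5 + delta, delta, b), u_1(0.5 - delta, delta, a)
        else:
            # Alternating: the second solve already uses the new value of b
            b = u_1(0.5 - delta, delta, a)
            a = u_2(0.5 + delta, delta, b)
    return abs(a - a_exact)

delta, k = 0.1, 5
alternating_error = interface_error(delta, k, parallel=False)
parallel_error = interface_error(delta, 2 * k, parallel=True)
print(alternating_error, parallel_error)
```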

### Final Comments on Domain Decomposition Methods

The parallel Schwarz iteration becomes particularly interesting in the case of **more than two subdomains**, allowing for the solution of very large problems in parallel:
+ If the problem size is increased at the same rate as the number of subdomains, the **cost for each local solve remains the same**.
+ One can observe: in a parallel computation, the computing time of each iteration of the alternating Schwarz method **increases linearly with the number of subdomains**, whereas the time of the parallel Schwarz method **remains constant** (neglecting communication).
+ Unfortunately, the **convergence rate of the parallel Schwarz iteration** will **decrease when the number of subdomains is increased**, because information can only travel to the neighboring subdomains within one iteration.
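
The last point can be illustrated with a small sketch. It is not the implementation used in this lecture; as an assumption, it applies the parallel Schwarz iteration to $u'' = 0$ on $[0,1]$ with $u(0) = 0$, $u(1) = 1$ (exact solution $u(x) = x$), where every local solution is a straight line between its two boundary values. The domain is kept fixed, the number of subdomains `P` is increased, and the overlap is a fixed fraction of the subdomain width. The function name `parallel_schwarz_iterations` is hypothetical.

```python
def parallel_schwarz_iterations(P, overlap=0.2, tol=1e-8, max_iter=100_000):
    """Count parallel Schwarz iterations for u'' = 0 on [0, 1] with u(0) = 0, u(1) = 1."""
    H = 1.0 / P
    d = overlap * H                                    # extension of each subdomain on both sides
    a = [max(0.0, j * H - d) for j in range(P)]        # left ends of the subdomains
    b = [min(1.0, (j + 1) * H + d) for j in range(P)]  # right ends of the subdomains
    left = [0.0] * P                                   # boundary values at the left ends
    right = [0.0] * (P - 1) + [1.0]                    # boundary values at the right ends

    def local(j, x):
        # Local solution on subdomain j: linear interpolation of its boundary values
        return left[j] + (right[j] - left[j]) * (x - a[j]) / (b[j] - a[j])

    for n in range(1, max_iter + 1):
        # All new boundary values are taken from the neighbors' previous solutions
        new_left = [left[0]] + [local(j - 1, a[j]) for j in range(1, P)]
        new_right = [local(j + 1, b[j]) for j in range(P - 1)] + [right[-1]]
        left, right = new_left, new_right
        error = max(max(abs(left[j] - a[j]), abs(right[j] - b[j])) for j in range(P))
        if error < tol:
            return n
    return max_iter


counts = {P: parallel_schwarz_iterations(P) for P in (2, 4, 8)}
print(counts)
```

Since boundary information is only passed from one subdomain to its direct neighbors in each iteration, the iteration count grows with the number of subdomains in this sketch.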


Moreover, there are also other domain decomposition methods, in particular a class of methods that are based on a **nonoverlapping decomposition of the computational domain**. 

_Here, we will only use the idea of classical domain decomposition solvers as a motivation to design decomposed neural network architectures; they will be covered in more detail in the **Computational Fluid Dynamics** course._

## Domain Decomposition for Function Approximation Using Neural Networks

In this section, we will discuss how to decompose neural networks to **approximate (nonlinear) functions**. In the final part of this lecture, we will then move on to discuss how to **decompose neural networks to solve partial differential equations using physics-informed neural networks**, which will require additional care.

### Training Performance of Neural Networks For Approximating Functions with Varying Frequencies

We have already observed that function approximation using neural networks is more challenging if the function to be approximated is highly oscillating. We will perform a more systematic study using the `FeedForwardNN` class from above.

In particular, we will investigate what happens if we increase the frequency of the function to be approximated at the same rate as the width of a neural network with a single hidden layer. In this case, one might expect that the function can always be approximated with the same accuracy. 

Let us investigate this using a Python example. We first define some required functions:

In [None]:
# Create a neural network
key = random.PRNGKey(0)
layer_sizes = [1, 10, 1] 
nn1 = FeedForwardNN(layer_sizes, key)

# Define the loss function (Mean Squared Error)
squared_error = jit(lambda params, x, y: (nn1.forward(params, x) - y) ** 2)
squared_errors_batch = vmap(squared_error, in_axes=(None, 0, 0))
@jit
def loss_fn(params, x, y):
    return jnp.mean(squared_errors_batch(params, x, y))

# Define training step with Adam optimizer
def train_step(params, opt_state, x, y):    
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss

We start with the function $\sin(\pi x)$:

In [None]:
# sin(pi*x) function
x = jnp.linspace(-1.0, 1.0, 1000)
y1 = jnp.sin(jnp.pi*x)

# Initialize the Adam optimizer with the learning rate schedule
learning_rate_schedule = optax.piecewise_constant_schedule(
    init_value=0.1,
    boundaries_and_scales={500: 0.1}
)
optimizer = optax.adam(learning_rate=learning_rate_schedule)
opt_state = optimizer.init(nn1.params)
losses1 = []

# Training loop
max_iterations1 = 20000

# Training loop with tqdm progress bar
pbar = trange(max_iterations1, desc="Training", leave=True)
for epoch in pbar:
    nn1.params, opt_state, current_loss = train_step(nn1.params, opt_state, x, y1)
    losses1.append(current_loss)
    if epoch % 100 == 0:
        pbar.set_postfix(loss=current_loss)
    if current_loss < 0.001:
        break

Next, we consider the function $\sin(4 \pi x)$:

In [None]:
# sin(4*pi*x) function
y2 = jnp.sin(4*jnp.pi*x)

# Create a neural network
key = random.PRNGKey(0)
layer_sizes = [1, 40, 1] 
nn2 = FeedForwardNN(layer_sizes, key)

# Define the loss function (Mean Squared Error)
squared_error = jit(lambda params, x, y: (nn2.forward(params, x) - y) ** 2)
squared_errors_batch = vmap(squared_error, in_axes=(None, 0, 0))
@jit
def loss_fn(params, x, y):
    return jnp.mean(squared_errors_batch(params, x, y))

# Define training step with Adam optimizer
def train_step(params, opt_state, x, y):    
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss

# Initialize the Adam optimizer with the learning rate schedule
learning_rate_schedule = optax.piecewise_constant_schedule(
    init_value=0.1,
    boundaries_and_scales={2000: 0.1, 6000: 0.1}
)
optimizer = optax.adam(learning_rate=learning_rate_schedule)
opt_state = optimizer.init(nn2.params)
losses2 = []

# Training loop
max_iterations2 = 20000

# Training loop with tqdm progress bar
pbar = trange(max_iterations2, desc="Training", leave=True)
for epoch in pbar:
    nn2.params, opt_state, current_loss = train_step(nn2.params, opt_state, x, y2)
    losses2.append(current_loss)
    if epoch % 100 == 0:
        pbar.set_postfix(loss=current_loss)
    if current_loss < 0.001:
        break

Finally, we consider the function $\sin(16 \pi x)$:

In [None]:
# sin(16*pi*x) function
y3 = jnp.sin(16*jnp.pi*x)

# Create a neural network
key = random.PRNGKey(0)
layer_sizes = [1, 160, 1] 
nn3 = FeedForwardNN(layer_sizes, key)

# Define the loss function (Mean Squared Error)
squared_error = jit(lambda params, x, y: (nn3.forward(params, x) - y) ** 2)
squared_errors_batch = vmap(squared_error, in_axes=(None, 0, 0))
@jit
def loss_fn(params, x, y):
    return jnp.mean(squared_errors_batch(params, x, y))

# Define training step with Adam optimizer
def train_step(params, opt_state, x, y):    
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss

# Initialize the Adam optimizer with the learning rate schedule
learning_rate_schedule = optax.piecewise_constant_schedule(
    init_value=0.01,
)
optimizer = optax.adam(learning_rate=learning_rate_schedule)
opt_state = optimizer.init(nn3.params)
losses3 = []

# Training loop
max_iterations3 = 20000

# Training loop with tqdm progress bar
pbar = trange(max_iterations3, desc="Training", leave=True)
for epoch in pbar:
    nn3.params, opt_state, current_loss = train_step(nn3.params, opt_state, x, y3)
    losses3.append(current_loss)
    if epoch % 100 == 0:
        pbar.set_postfix(loss=current_loss)
    if current_loss < 0.001:
        break

Now, we plot all the results:

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

# Create a figure with a 3x2 grid of subplots (loss and prediction for each network)
fig, axs = plt.subplots(3, 2, figsize=(20, 10))

# Plot the loss over the number of iterations for nn1
axs[0, 0].set_yscale('log')
axs[0, 0].plot(range(len(losses1)), losses1, label='Loss')
axs[0, 0].set_xlabel('Iteration')
axs[0, 0].set_ylabel('Loss')
axs[0, 0].legend()

# Plot the results for nn1
axs[0, 1].plot(x, y1, label='Sine function')
axs[0, 1].plot(x, nn1.predict(x).reshape(y1.shape), label='Neural network')
axs[0, 1].legend()

# Plot the loss over the number of iterations for nn2
axs[1, 0].set_yscale('log')
axs[1, 0].plot(range(len(losses2)), losses2, label='Loss')
axs[1, 0].set_xlabel('Iteration')
axs[1, 0].set_ylabel('Loss')
axs[1, 0].legend()

# Plot the results for nn2
axs[1, 1].plot(x, y2, label='Sine function')
axs[1, 1].plot(x, nn2.predict(x).reshape(y2.shape), label='Neural network')
axs[1, 1].legend()

# Plot the loss over the number of iterations for nn3
axs[2, 0].set_yscale('log')
axs[2, 0].plot(range(len(losses3)), losses3, label='Loss')
axs[2, 0].set_xlabel('Iteration')
axs[2, 0].set_ylabel('Loss')
axs[2, 0].legend()

# Plot the results for nn3
axs[2, 1].plot(x, y3, label='Sine function')
axs[2, 1].plot(x, nn3.predict(x).reshape(y3.shape), label='Neural network')
axs[2, 1].legend()

# Show the plots
plt.show()

So, it is **not sufficient to scale up the number of neurons with the frequency of the function to be approximated**. However, there is a simple way to improve the approximation performance of the neural network, based on the idea of domain decomposition.

### A Simple Domain Decomposition Approach for Neural Networks

The idea of domain decomposition is to **split the problem into smaller subproblems**, based on a **decomposition of the spatial domain**. Here, we employ a domain decomposition approach to **decompose the input domain**, that is the one-dimensional interval $[-1,1]$, into $P$ subintervals.

Let us assume that a single-layer neural network 
$$
    \mathcal{N}_\theta (x) = A^\top \sigma \big( W x + b \big),
$$
where $W \in \mathbb{R}^{m}$, $b \in \mathbb{R}^m$, and $A \in \mathbb{R}^m$, can approximate the function $\sin (\pi x)$ well on the interval $[-1,1]$. 

Then, we would assume that a neural network
$$
    \mathcal{N}_{\hat \theta} (x) = \hat A^\top \sigma \big(\hat W x + \hat b \big),
$$
with $\hat W \in \mathbb{R}^{km}$, $\hat b \in \mathbb{R}^{km}$, and $\hat A \in \mathbb{R}^{km}$, can approximate the function $\sin (k \pi x)$ well. But as we have seen, the situation is not as simple as that.

Therefore, let us first **decompose the interval** $[-1,1]$ into $k$ subintervals $I_1, \ldots, I_k$, with 
$$
    I_j = \big[ \underbrace{\frac{2(j-1)}{k}-1}_{=: x_j^s}, \underbrace{\frac{2j}{k}-1}_{=: x_j^e} \big)
$$
In particular, let
$$
    \omega_j (x)
    =
    \begin{cases}
        1 & \text{if} \quad x \in I_j, \\
        0 & \text{otherwise}.
    \end{cases}
$$

Note that these functions are a **partition of unity**, i.e., 
$$
    \sum_{j=1}^{k} \omega_j (x) = 1
$$ 
for all $x \in [-1,1)$; the right endpoint $x = 1$ is not contained in any of the half-open intervals.
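
The partition of unity property can be checked numerically. The following is a minimal NumPy sketch (the notebook itself uses JAX); the helper `indicator_windows` is a hypothetical name introduced only for this check.

```python
import numpy as np


def indicator_windows(x, k):
    """Return an array of shape (k, len(x)) containing omega_j(x) for j = 1, ..., k."""
    j = np.arange(1, k + 1)[:, None]      # subinterval indices as a column
    x_start = 2.0 * (j - 1) / k - 1.0     # left ends x_j^s
    x_end = 2.0 * j / k - 1.0             # right ends x_j^e
    return ((x[None, :] >= x_start) & (x[None, :] < x_end)).astype(float)


k = 16
x = np.linspace(-1.0, 1.0, 1000, endpoint=False)  # sample points in [-1, 1)
omega = indicator_windows(x, k)
total = omega.sum(axis=0)                         # sum over all window functions
print(total.min(), total.max())

# The right endpoint x = 1 lies in none of the half-open intervals
total_at_one = indicator_windows(np.array([1.0]), k).sum()
print(total_at_one)
```

Every sample point in $[-1,1)$ is covered by exactly one window, whereas the sum vanishes at $x = 1$ because the intervals are half-open.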

Then, we define a **new neural network architecture** as follows
$$
    \mathcal{N}_{\theta} (x) = \sum_{j=1}^{k} \omega_j (x) \mathcal{N}_{j,\theta_j} (x),
$$
where $\mathcal{N}_{j,\theta_j} (x)$ is a neural network with an architecture that can approximate a single period of the sine well, that is,
$$
    \mathcal{N}_{j,\theta_j} (x) = A_j^\top \sigma \big( W_j x + b_j \big),
$$
where $W_j \in \mathbb{R}^{m}$, $b_j \in \mathbb{R}^m$, and $A_j \in \mathbb{R}^m$.

This means that the support of the function
$$
    \mathcal{F}_{j,\theta_j} (x) = \omega_j (x) \mathcal{N}_{j,\theta_j} (x)
$$
is only inside the interval $I_j$. Outside the interval, the function will be zero. 

Hence, if we evaluate the loss function 
$$
    \mathcal{L}(\theta) = \frac{1}{n} \sum_{i=1}^{n} \left\| \mathcal{N}_\theta (x_i) - y_i \right\|^2
$$
with this network architecture, this can be rewritten as
$$
    \mathcal{L}(\theta) 
    =
    \frac{1}{n} \sum_{i=1}^n \left\| \sum_{j=1}^{k} \mathcal{F}_{j,\theta_j} (x_i) - y_i \right\|^2
    =
    \frac{1}{n} \sum_{j=1}^{k} \sum_{x_i \in I_j} \left\| \mathcal{F}_{j,\theta_j} (x_i) - y_i \right\|^2.
$$
In the last step, we have used that the support of the function $\mathcal{F}_{j,\theta_j} (x)$ is inside the interval $I_j$ and that the intervals are disjoint, so that each data point $x_i$ lies in exactly one interval. 

This means that, for an **efficient implementation**, we do not even have to evaluate $\mathcal{F}_{j,\theta_j} (x_i)$ outside the interval $I_j$. For simplicity, we will not implement this here, but it is a simple modification of the code.
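
To make this concrete, the following NumPy sketch compares the loss evaluated with the full masked sum and the loss evaluated subinterval by subinterval, where each local network is only applied to its own data points. The local networks use random, untrained parameters; `local_net`, `masks`, and the chosen sizes are assumptions for illustration and not part of the classes used in this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)
k, m, n = 4, 5, 200                      # subintervals, hidden neurons, data points
x = np.linspace(-1.0, 1.0, n, endpoint=False)
y = np.sin(k * np.pi * x)

# Hypothetical local networks N_j(x) = A_j^T tanh(W_j x + b_j) with random parameters
W, b, A = rng.normal(size=(3, k, m))


def local_net(j, x):
    return np.tanh(np.outer(x, W[j]) + b[j]) @ A[j]


# Indicator functions of the half-open subintervals I_j
edges = np.linspace(-1.0, 1.0, k + 1)
masks = [(x >= edges[j]) & (x < edges[j + 1]) for j in range(k)]

# Global form: every local network is evaluated at every data point
prediction = sum(masks[j] * local_net(j, x) for j in range(k))
loss_global = np.mean((prediction - y) ** 2)

# Local form: every local network is only evaluated at the data points in I_j
loss_local = sum(
    np.sum((local_net(j, x[masks[j]]) - y[masks[j]]) ** 2) for j in range(k)
) / n

print(loss_global, loss_local)
```

Both values agree up to rounding, but the local form needs only one network evaluation per data point instead of $k$.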

Moreover, the network corresponding to a subinterval only has to consider **data points inside $I_j$**. For the network to learn the function on this interval well, we should normalize the input of $\mathcal{N}_{j,\theta_j}$, for instance, to the interval $[-1,1]$; the appropriate range will also **depend on the weight initialization**. As a consequence, we introduce a **normalization step**
$$
    {\rm norm}(x) = \frac{x - \frac{1}{2}\left(x_j^s + x_j^e\right)}{\frac{1}{2}\left(x_j^e - x_j^s\right)}
$$
and consider the final neural network architecture 
$$
    \mathcal{N}_{\theta} (x) = \sum_{j=1}^{k} \omega_j (x) \mathcal{N}_{j,\theta_j} ({\rm norm}(x)).
$$

As a result, **every individual network will only see a single interval**, for instance, a single period of the sine.
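
The effect of the normalization can be checked with a short NumPy sketch, assuming the target $\sin(k \pi x)$ with an even $k$ equal to the number of subintervals. In the normalized coordinate $t = {\rm norm}(x) \in [-1,1)$, the target on every subinterval is then the same function $-\sin(\pi t)$. The variable names are chosen for illustration only.

```python
import numpy as np

k = 16
x = np.linspace(-1.0, 1.0, 1600, endpoint=False)
y = np.sin(k * np.pi * x)

max_deviation = 0.0
for j in range(1, k + 1):
    x_start = 2.0 * (j - 1) / k - 1.0   # x_j^s
    x_end = 2.0 * j / k - 1.0           # x_j^e
    inside = (x >= x_start) & (x < x_end)
    # Normalized coordinate of the data points in I_j
    t = (x[inside] - 0.5 * (x_start + x_end)) / (0.5 * (x_end - x_start))
    # Compare the local target with the single period -sin(pi t)
    deviation = np.max(np.abs(y[inside] + np.sin(np.pi * t)))
    max_deviation = max(max_deviation, deviation)

print(max_deviation)
```

Hence, all subnetworks face the same low-frequency approximation problem, independent of the global frequency.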

### Python Example of Domain Decomposition for Function Approximation

We will now implement exactly the architecture proposed in the previous section. We will consider the function $\sin (16 \pi x)$ and decompose the interval $[-1,1]$ into $k$ subintervals. We will then train the network to approximate the function using the domain decomposition approach.

First, we define the neural network architecture:

In [None]:
# Generate sample data
k = 16
x = jnp.linspace(-1.0, 1.0, 1000)
y = jnp.sin(k * jnp.pi * x)

# Define the domain decomposition neural network
class DomainDecomposedNN:
    def __init__(self, num_subdomains, layer_sizes, key, activation_fn=jax.nn.tanh):
        self.num_subdomains = num_subdomains
        # Split the key once so that every subnetwork gets its own random initialization
        keys = random.split(key, num_subdomains)
        self.subnets = [FeedForwardNN(layer_sizes, keys[i], activation_fn) for i in range(num_subdomains)]
        self.params = [subnet.params for subnet in self.subnets]

    def forward(self, params, x):
        subdomain_size = 2.0 / self.num_subdomains
        outputs = []
        for i in range(self.num_subdomains):
            mask = jnp.where((x >= -1.0 + i * subdomain_size) & (x < -1.0 + (i + 1) * subdomain_size), 1.0, 0.0)
            norm_x = 2 * (x - (-1.0 + (i + 0.5) * subdomain_size)) / subdomain_size
            outputs.append(mask * self.subnets[i].forward(params[i], norm_x))
        return jnp.sum(jnp.array(outputs))

Then, we define the loss function:

In [None]:
# Define the loss function (Mean Squared Error)
squared_error = jit(lambda params, x, y: (nn.forward(params, x) - y) ** 2)
squared_errors_batch = vmap(squared_error, in_axes=(None, 0, 0))
@jit
def loss_fn(params, x, y):    
    return jnp.mean(squared_errors_batch(params, x, y))

# Define training step with Adam optimizer
def train_step(params, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss

Now, we train the neural network:

In [None]:
# Initialize the domain decomposed neural network
key = random.PRNGKey(0)
layer_sizes = [1, 10, 1]
nn = DomainDecomposedNN(k, layer_sizes, key)

# Create a learning rate schedule
learning_rate_schedule = optax.piecewise_constant_schedule(
    init_value=0.1,
    boundaries_and_scales={300: 0.5}
)

# Initialize the Adam optimizer with the learning rate schedule
optimizer = optax.adam(learning_rate=learning_rate_schedule)
opt_state = optimizer.init(nn.params)

# Training loop
losses = []
max_iterations = 1000

# Training loop with tqdm progress bar
pbar = trange(max_iterations, desc="Training", leave=True)
for epoch in pbar:
    nn.params, opt_state, current_loss = train_step(nn.params, opt_state, x, y)
    losses.append(current_loss)
    if epoch % 100 == 0:
        pbar.set_postfix(loss=current_loss)
    if current_loss < 0.001:
        break

Finally, we plot the results:

In [None]:
# Create a figure with two subplots
fig, axs = plt.subplots(1, 2, figsize=(10, 2))

# Plot the loss over the number of iterations
axs[0].set_yscale('log')
axs[0].plot(range(len(losses)), losses, label='Loss')
axs[0].set_xlabel('Iteration')
axs[0].set_ylabel('Loss')
axs[0].legend()

# Plot the results
predictions = vmap(nn.forward, in_axes=(None, 0))(nn.params, x)
axs[1].plot(x, y, label='Sine function')
axs[1].plot(x, jnp.squeeze(predictions), label='Neural network')
axs[1].legend()

# Show the plots
plt.show()

## Domain Decomposition for Physics-Informed Neural Networks

Finally, we apply the decomposition approach from the previous section in the context of physics-informed neural networks. To this end, consider the ordinary differential equation
$$
    \begin{aligned}
        u'(x) & = k \pi \cos(k \pi x), \quad x \in [-1,1], \\
        u(0) & = 0.
    \end{aligned}
$$

Then, the exact solution is given by
$$
    u(x) = \sin(k \pi x).
$$

### Python Example of a Physics-Informed Neural Network Without Domain Decomposition

Let us first try to approximate the solution using a single physics-informed neural network without domain decomposition. In particular, we employ a simple neural network $\mathcal{N}_\theta (x)$ with a single hidden layer.

For simplicity, we enforce the boundary condition $u(0) = 0$ via hard constraints:
$$
    u_{\theta} (x) = \tanh(k \pi x) \mathcal{N}_\theta (x).
$$

Therefore, in total, we employ the physics-informed loss function
$$
    \mathcal{L}(\theta) = \frac{1}{n} \sum_{i=1}^{n} \left\| {u_{\theta}}' (x_i) - k \pi \cos(k \pi x_i) \right\|^2.
$$

In [None]:
from jax import grad

# Generate sample data
k = 16
x = jnp.linspace(-1.0, 1.0, 1000)
y = jnp.sin(k * jnp.pi * x)

# Initialize a single (non-decomposed) neural network
key = random.PRNGKey(0)
layer_sizes = [1, 160, 1]
nn = FeedForwardNN(layer_sizes, key)

# Define the neural network function u
sigma = jit(lambda x: jax.nn.tanh(k * jnp.pi * x))
u_nn = jit(lambda params, x: sigma(x) * nn.forward(params, x).squeeze())
u_nn_batch = vmap(u_nn, (None, 0))

# Compute gradients
u_nn_x = grad(u_nn, argnums=1)

# Define the residual function
squared_residual = jit(lambda params, x: (u_nn_x(params, x) - k * jnp.pi * jnp.cos(k * jnp.pi * x)) ** 2)
squared_residual_batch = vmap(squared_residual, (None, 0))

# Define the physics-informed loss function
@jit
def physics_informed_loss(params, x):
    physics_loss = jnp.mean(squared_residual_batch(params, x))
    return physics_loss

# Define training step with Adam optimizer
def train_step(params, opt_state, x):
    loss, grads = jax.value_and_grad(physics_informed_loss)(params, x)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss

# Create a learning rate schedule
learning_rate_schedule = optax.piecewise_constant_schedule(
    init_value=0.01,
    # boundaries_and_scales={200: 0.1}
)

We set up the training process:

In [None]:
# Initialize the Adam optimizer with the learning rate schedule
optimizer = optax.adam(learning_rate=learning_rate_schedule)
opt_state = optimizer.init(nn.params)

# Training loop
losses = []
max_iterations = 20000

x = jnp.linspace(-1.0, 1.0, 1000)

# Training loop with tqdm progress bar
pbar = trange(max_iterations, desc="Training", leave=True)
for epoch in pbar:
    nn.params, opt_state, current_loss = train_step(nn.params, opt_state, x)
    losses.append(current_loss)
    if epoch % 100 == 0:
        pbar.set_postfix(loss=current_loss)
    if current_loss < 0.001:
        break

And finally, we plot the results:

In [None]:
# Create a figure with two subplots
fig, axs = plt.subplots(1, 2, figsize=(10, 2))

# Plot the loss over the number of iterations
axs[0].set_yscale('log')
axs[0].plot(range(len(losses)), losses, label='Loss')  # len(losses) is also valid if training stopped early
axs[0].set_xlabel('Iteration')
axs[0].set_ylabel('Loss')
axs[0].legend()

# Plot the results
predictions = u_nn_batch(nn.params, x)
axs[1].plot(x, y, label='Sine function')
axs[1].plot(x, jnp.squeeze(predictions), label='Neural network')
axs[1].set_ylim([-1, 1])
axs[1].legend()

# Show the plots
plt.show()

### First Python Example of Domain Decomposition for Physics-Informed Neural Networks

**Combining the domain decomposition approach with physics-informed neural networks**, we decompose the interval $[-1,1]$ into $k$ subintervals $I_1,\ldots,I_k$ and use the same architecture as before:
$$
    \mathcal{N}_{\theta} (x) = \sum_{j=1}^{k} \omega_j (x) \mathcal{N}_{j,\theta_j} ({\rm norm}(x)),
$$
with
$$
    \omega_j (x)
    =
    \begin{cases}
        1 & \text{if} \quad x \in I_j, \\
        0 & \text{otherwise}.
    \end{cases}
$$

With strong enforcement of boundary constraints, we obtain
$$
    u_{\theta} (x) = \tanh(k \pi x) \sum_{j=1}^{k} \omega_j (x) \mathcal{N}_{j,\theta_j} ({\rm norm}(x))
$$
and the same physics-informed loss function
$$
    \mathcal{L}(\theta) = \frac{1}{n} \sum_{i=1}^{n} \left\| {u_{\theta}}' (x_i) - k \pi \cos(k \pi x_i) \right\|^2.
$$

We set up the loss function and training step:

In [None]:
from jax import jit, grad

# Generate sample data
k = 16
x = jnp.linspace(-1.0, 1.0, 1000)
y = jnp.sin(k * jnp.pi * x)

# Initialize the domain decomposed neural network
key = random.PRNGKey(0)
layer_sizes = [1, 10, 1]
nn = DomainDecomposedNN(k, layer_sizes, key)

# Define the neural network function u
sigma = jit(lambda x: jax.nn.tanh(k * jnp.pi * x))
u_nn = jit(lambda params, x: sigma(x) * nn.forward(params, x))
u_nn_batch = vmap(u_nn, (None, 0))

# Compute gradients
u_nn_x = grad(u_nn, argnums=1)

# Define the residual function
squared_residual = jit(lambda params, x: (u_nn_x(params, x) - k * jnp.pi * jnp.cos(k * jnp.pi * x)) ** 2)
squared_residual_batch = vmap(squared_residual, (None, 0))

# Define the physics-informed loss function
@jit
def physics_informed_loss(params, x):
    physics_loss = jnp.mean(squared_residual_batch(params, x))
    return physics_loss

# Define training step with Adam optimizer
def train_step(params, opt_state, x):
    loss, grads = jax.value_and_grad(physics_informed_loss)(params, x)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss

We train the model:

In [None]:
# Create a learning rate schedule
learning_rate_schedule = optax.piecewise_constant_schedule(
    init_value=0.01,
    # boundaries_and_scales={200: 0.1}
)

# Initialize the Adam optimizer with the learning rate schedule
optimizer = optax.adam(learning_rate=learning_rate_schedule)
opt_state = optimizer.init(nn.params)

# Training loop
losses = []
max_iterations = 5000

x = jnp.linspace(-1.0, 1.0, 1000)

# Training loop with tqdm progress bar
pbar = trange(max_iterations, desc="Training", leave=True)
for epoch in pbar:
    nn.params, opt_state, current_loss = train_step(nn.params, opt_state, x)
    losses.append(current_loss)
    if epoch % 100 == 0:
        pbar.set_postfix(loss=current_loss)
    if current_loss < 0.001:
        break

Finally, we plot the results:

In [None]:
# Create a figure with two subplots
fig, axs = plt.subplots(1, 2, figsize=(10, 2))

# Plot the loss over the number of iterations
axs[0].set_yscale('log')
axs[0].plot(range(len(losses)), losses, label='Loss')  # len(losses) is also valid if training stopped early
axs[0].set_xlabel('Iteration')
axs[0].set_ylabel('Loss')
axs[0].legend()

# Plot the results
predictions = u_nn_batch(nn.params, x)
axs[1].plot(x, y, label='Sine function')
axs[1].plot(x, jnp.squeeze(predictions), label='Neural network')
axs[1].set_ylim([-1, 1])
axs[1].legend()

# Show the plots
plt.show()

Somehow, this does not seem to work properly. Recalling the Schwarz methods, the **overlap was essential for transferring information between the subdomains**. In particular, the neural network function
$$
    \mathcal{N}_{\theta} (x) = \sum_{j=1}^{k} \omega_j (x) \mathcal{N}_{j,\theta_j} ({\rm norm}(x)),
$$
is **not necessarily continuous** across the whole domain $[-1,1]$ because at the interface $\hat x_{j+1} = \overline{I_j} \cap \overline{I_{j+1}}$, we can have that
$$
    \lim_{x \nearrow \hat x_{j+1}} \omega_j (x) \mathcal{N}_{j,\theta_j} ({\rm norm}(x)) \neq \lim_{x \searrow \hat x_{j+1}} \omega_{j+1} (x) \mathcal{N}_{j+1,\theta_{j+1}} ({\rm norm}(x)),
$$
and there is no constraint on the neural network functions $\mathcal{N}_{j,\theta_j}$ and $\mathcal{N}_{j+1,\theta_{j+1}}$ to balance that. 

In particular, the differential equation
$$
    u'(x) = k \pi \cos(k \pi x),
$$
and hence the physics loss, only constrains the derivative of the neural network function. This is not the case for function approximation, where the loss constrains the function values themselves.
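
This can be made concrete with a hand-constructed example instead of a trained network. The sketch below, written in NumPy, shifts the exact solution by a different constant on each subinterval; the offsets are arbitrary values chosen for illustration, and the hard-constraint factor $\tanh(k \pi x)$ is ignored for simplicity. The derivative is approximated by central differences at collocation points that do not lie on an interface.

```python
import numpy as np

k = 4
offsets = np.array([0.0, 0.7, -0.4, 0.2])  # hypothetical constant shifts, one per subinterval


def u_piecewise(x):
    # Index of the subinterval containing x (zero-based)
    j = np.clip(np.floor((x + 1.0) * k / 2.0).astype(int), 0, k - 1)
    return np.sin(k * np.pi * x) + offsets[j]


# Collocation points at cell midpoints, so that no point lies on an interface
n = 1000
x = -1.0 + (np.arange(n) + 0.5) * 2.0 / n

# Central difference approximation of u' inside the subintervals
h = 1e-6
du = (u_piecewise(x + h) - u_piecewise(x - h)) / (2.0 * h)

residual = du - k * np.pi * np.cos(k * np.pi * x)
physics_loss = np.mean(residual ** 2)
function_error = np.max(np.abs(u_piecewise(x) - np.sin(k * np.pi * x)))

print(physics_loss, function_error)
```

The physics loss is zero up to discretization error, although the function differs from $\sin(k \pi x)$ by the largest offset. The loss alone therefore cannot detect jumps at the interfaces.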

### Overlapping Schwarz Domain Decomposition for Physics-Informed Neural Networks

In the final part of today's lecture, we will show that using an **overlapping domain decomposition can solve this problem**. Therefore, we will now employ a neural network architecture which is based on an overlapping domain decomposition, which also requires us to modify the partition of unity functions $\omega_j$.

This is based on the papers:
> Moseley, B., Markham, A., & Nissen-Meyer, T. (2023). Finite basis physics-informed neural networks (FBPINNs): a scalable domain decomposition approach for solving differential equations. Advances in Computational Mathematics, 49(4), 62.

> Dolean, V., Heinlein, A., Mishra, S., & Moseley, B. (2024). Multilevel domain decomposition-based architectures for physics-informed neural networks. Computer Methods in Applied Mechanics and Engineering, 429, 117116.

Let us first decompose the interval $[-1,1]$ into $k$ overlapping subintervals $I_1, \ldots, I_k$, with 
$$
    I_j = \big[ \underbrace{\frac{2(j-1) - \delta}{k}-1}_{=: x_j^s}, \underbrace{\frac{2j+\delta}{k}-1}_{=: x_j^e} \big] \cap [-1,1],
$$
where $\delta$ indicates the relative overlap of the subdomains. In order to obtain a partition of unity on these overlapping intervals, we define the partition of unity functions as follows
$$
    \omega_j (x) 
    =
    \begin{cases}
        0 & \text{if} \quad x < x_j^s, \\
        \frac{1}{2}\left(1-\cos\left(k \pi \frac{x - x_j^s}{2\delta}\right)\right) & \text{if} \quad x_j^s \leq x < x_j^s+\frac{2\delta}{k}, \\
        1 & \text{if} \quad x_j^s+\frac{2\delta}{k} \leq x < x_j^e-\frac{2\delta}{k}, \\
        \frac{1}{2}\left(1+\cos\left(k \pi \frac{x - x_j^e}{2\delta} + \pi \right)\right) & \text{if} \quad x_j^e-\frac{2\delta}{k} \leq x < x_j^e, \\
        0 & \text{if} \quad x_j^e \leq x.
    \end{cases}
$$

This function looks like:

In [None]:
import jax.numpy as jnp
import matplotlib.pyplot as plt

# Parameters
k = 16
delta = 0.5
j = 2

def omega_j(x, j, k, delta):
    if j == 0:
        x_js = -1 - 2 * delta
    else :
        x_js = (2 * j - delta) / k - 1
    if j == k-1:
        x_je = 1 + 2 * delta
    else :
        x_je = (2 * (j + 1) + delta) / k - 1
    return jnp.where(
        (x < x_js) | (x >= x_je), 0.0,
        jnp.where(
            (x_js <= x) & (x < x_js + 2 * delta / k),
            0.5*(1 - jnp.cos(jnp.pi * k * (x - x_js) / (2 * delta))),
            jnp.where(
                (x_js + 2 * delta / k <= x) & (x < x_je - 2 * delta / k),
                1.0,
                jnp.where(
                    (x_je - 2 * delta / k <= x) & (x < x_je),
                    0.5*(1+jnp.cos(jnp.pi * k * (x - x_je) / (2 * delta) + jnp.pi)),
                    0.0
                )
            )
        )
    )

# Generate x values
x = jnp.linspace(-1.0, 1.0, 1000)
omega_values = omega_j(x, j, k, delta)

# Compute the gradient of omega_j
omega_grad = grad(omega_j, argnums=0)
omega_grad_batch = vmap(omega_grad, in_axes=(0, None, None, None))
omega_grad_values = omega_grad_batch(x, j, k, delta)

# Plot the function and its gradient
fig, ax1 = plt.subplots()

ax1.set_xlabel('x')
ax1.set_ylabel(f'omega_{j}(x)', color='tab:blue')
ax1.plot(x, omega_values, label=f'omega_{j}(x)', color='tab:blue')
ax1.tick_params(axis='y', labelcolor='tab:blue')

ax2 = ax1.twinx()
ax2.set_ylabel(f"omega_{j}'(x)", color='tab:red')
ax2.plot(x, omega_grad_values, label=f"omega_{j}'(x)", color='tab:red')
ax2.tick_params(axis='y', labelcolor='tab:red')

fig.tight_layout()
plt.title(f'Plot of omega_{j}(x) and its gradient with j={j}, k={k}, delta={delta}')
plt.show()
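
As a sanity check, we can verify that the overlapping window functions sum to one on $[-1,1]$. The following sketch re-implements the window in NumPy so that it is self-contained; `omega_np` is a hypothetical name, and it follows the zero-based indexing and the boundary treatment of `omega_j` above (the first and last windows are extended beyond the domain so that they equal one at $x = \pm 1$).

```python
import numpy as np


def omega_np(x, j, k, delta):
    """NumPy version of the cosine window for the zero-based subdomain index j."""
    x_s = -1.0 - 2.0 * delta if j == 0 else (2.0 * j - delta) / k - 1.0
    x_e = 1.0 + 2.0 * delta if j == k - 1 else (2.0 * (j + 1) + delta) / k - 1.0
    w = 2.0 * delta / k                                    # width of each transition zone
    rise = 0.5 * (1.0 - np.cos(np.pi * (x - x_s) / w))     # increasing branch
    fall = 0.5 * (1.0 + np.cos(np.pi * (x - x_e) / w + np.pi))  # decreasing branch
    return np.where(
        (x < x_s) | (x >= x_e), 0.0,
        np.where(x < x_s + w, rise, np.where(x < x_e - w, 1.0, fall)),
    )


k, delta = 16, 0.5
x = np.linspace(-1.0, 1.0, 1000)
total = sum(omega_np(x, j, k, delta) for j in range(k))
print(total.min(), total.max())
```

In each overlap region, the decreasing branch of one window and the increasing branch of its neighbor add up to $\frac{1}{2}(1 + \cos t) + \frac{1}{2}(1 - \cos t) = 1$.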

Let us now set up the overlapping domain decomposition-based neural network architecture:

In [None]:
# Define the domain decomposition neural network
class OverlappingDomainDecomposedNN:
    def __init__(self, num_subdomains, overlap, layer_sizes, key, activation_fn=jax.nn.tanh):
        self.num_subdomains = num_subdomains
        self.overlap = overlap
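        # Split the key so that the subnetworks do not all start from identical parameters
        keys = random.split(key, num_subdomains)
        self.subnets = [FeedForwardNN(layer_sizes, keys[j], activation_fn) for j in range(num_subdomains)]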
        self.params = [subnet.params for subnet in self.subnets]
        self.x_js = jnp.array([jnp.maximum((2 * j - overlap) / num_subdomains - 1.0, -1.0) for j in range(num_subdomains)])
        self.x_je = jnp.array([jnp.minimum((2 * (j + 1) + overlap) / num_subdomains - 1.0, 1.0) for j in range(num_subdomains)])

    def forward(self, params, x):
        outputs = []
        for i in range(self.num_subdomains):
            mask = omega_j(x, i, self.num_subdomains, self.overlap)
            norm_x = 2 * (x - (self.x_js[i] + self.x_je[i]) / 2) / (self.x_je[i] - self.x_js[i])
            outputs.append(mask * self.subnets[i].forward(params[i], norm_x))
        return jnp.sum(jnp.array(outputs))

Then, we set up the neural network and loss function:

In [None]:
from jax import jit, grad

# Parameters
k = 16
delta = 0.5

# Generate sample data
x = jnp.linspace(-1.0, 1.0, 1000)
y = jnp.sin(k * jnp.pi * x)

# Initialize the domain decomposed neural network
key = random.PRNGKey(0)
layer_sizes = [1, 20, 1]
nn = OverlappingDomainDecomposedNN(k, delta, layer_sizes, key)

# Define the neural network function u
sigma = jit(lambda x: jax.nn.tanh(k * jnp.pi * x))
u_nn = jit(lambda params, x: nn.forward(params, x) * sigma(x))
u_nn_batch = vmap(u_nn, (None, 0))

# Compute gradients
u_nn_x = grad(u_nn, argnums=1)

# Define the residual function
squared_residual = jit(lambda params, x: (u_nn_x(params, x) - k * jnp.pi * jnp.cos(k * jnp.pi * x)) ** 2)
squared_residual_batch = vmap(squared_residual, (None, 0))

# Define the physics-informed loss function
@jit
def physics_informed_loss(params, x):
    physics_loss = jnp.mean(squared_residual_batch(params, x))
    return physics_loss

Then, we train the model:

In [None]:
# Define training step with Adam optimizer
def train_step(params, opt_state, x):
    loss, grads = jax.value_and_grad(physics_informed_loss)(params, x)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss

# Create a learning rate schedule
learning_rate_schedule = optax.piecewise_constant_schedule(
    init_value=0.01,
    boundaries_and_scales={2500: 0.1}
)

# Initialize the Adam optimizer with the learning rate schedule
optimizer = optax.adam(learning_rate=learning_rate_schedule)
opt_state = optimizer.init(nn.params)

# Training loop
losses = []
max_iterations = 2000

# Training loop with tqdm progress bar
pbar = trange(max_iterations, desc="Training", leave=True)
for epoch in pbar:
    nn.params, opt_state, current_loss = train_step(nn.params, opt_state, x)
    losses.append(current_loss)
    if epoch % 100 == 0:
        pbar.set_postfix(loss=current_loss)
    if current_loss < 0.001:
        break

Finally, we plot the results:

In [None]:
# Create a figure with two subplots
fig, axs = plt.subplots(1, 2, figsize=(10, 2))

# Plot the loss over the number of iterations
axs[0].set_yscale('log')
axs[0].plot(range(len(losses)), losses, label='Loss')  # len(losses) is also valid if training stopped early
axs[0].set_xlabel('Iteration')
axs[0].set_ylabel('Loss')
axs[0].legend()

# Plot the results
predictions = u_nn_batch(nn.params, x)
axs[1].plot(x, y, label='Sine function')
axs[1].plot(x, predictions, label='Neural network')
axs[1].set_ylim([-1, 1])
axs[1].legend()

# Show the plots
plt.show()

### Some Final Remarks on the Importance of the Overlap

As an **alternative to using an overlapping domain decomposition**, we can also modify the approach of a nonoverlapping domain decomposition. As mentioned before, the problem is the missing transfer of spatial information provided by the overlap. To account for that, we can alternatively enforce continuity of the solution across the whole domain $[-1,1]$ via additional constraints.

In particular, we could enforce 
$$
    \lim_{x \nearrow \hat x_{j+1}} \omega_j (x) \mathcal{N}_{j,\theta_j} ({\rm norm}(x)) = \lim_{x \searrow \hat x_{j+1}} \omega_{j+1} (x) \mathcal{N}_{j+1,\theta_{j+1}} ({\rm norm}(x))
$$
by adding the penalty term
$$
    \left( \mathcal{N}_{j,\theta_j} ({\rm norm}(x)) - \mathcal{N}_{j+1,\theta_{j+1}} ({\rm norm}(x)) \right)^2
$$
to the loss function for all interface points
$$
    x \in \bigcup_{j = 1,\ldots,k-1} ( \overline{I_j} \cap \overline{I_{j+1}} ),
$$
where each network is evaluated with the normalization of its own subinterval; the window functions can be dropped because they are equal to one on their respective subintervals.
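
A continuity penalty of this kind can be sketched as follows. This is a minimal NumPy illustration and not part of the implementation above: the helper `interface_penalty` and the hand-constructed local functions are assumptions. Each local function takes the normalized coordinate $t \in [-1,1]$, so the interface between two neighboring subintervals corresponds to $t = 1$ for the left one and to $t = -1$ for the right one.

```python
import numpy as np


def interface_penalty(local_nets):
    """Sum of squared jumps between neighboring local functions at the interior interfaces."""
    return sum(
        (left(1.0) - right(-1.0)) ** 2
        for left, right in zip(local_nets[:-1], local_nets[1:])
    )


k = 4
centers = [(2 * j + 1) / k - 1.0 for j in range(k)]  # midpoints of the subintervals (zero-based j)
g = lambda x: np.sin(np.pi * x)                      # a globally continuous reference function

# Local functions that reproduce g on their subinterval (x = center + t / k)
continuous = [lambda t, c=c: g(c + t / k) for c in centers]

# The same local functions, shifted by hypothetical constants: jumps at the interfaces
shifts = [0.0, 0.7, -0.4, 0.2]
shifted = [lambda t, c=c, s=s: g(c + t / k) + s for c, s in zip(centers, shifts)]

penalty_continuous = interface_penalty(continuous)
penalty_shifted = interface_penalty(shifted)
print(penalty_continuous, penalty_shifted)
```

The penalty vanishes for the continuous ensemble and is positive as soon as neighboring local functions disagree at an interface. In a training loop, such a term would be added to the physics loss, typically with a weighting factor.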

## Thank you for your attention!

## Questions?

## Appendix

### 1D Laplace Example

In [None]:
# Parameters
k = 8
delta = 0.5

# Generate sample data
x = jnp.linspace(0.0, 1.0, 1000)
y = 0.5 * x * (1 - x)

# Initialize the domain decomposed neural network
key = random.PRNGKey(0)
layer_sizes = [1, 20, 1]
nn = OverlappingDomainDecomposedNN(k, delta, layer_sizes, key)

# Define the neural network function u
sigma = jit(lambda x: x * (1 - x))
u_nn = jit(lambda params, x: sigma(x) * nn.forward(params, x).squeeze())
u_nn_batch = vmap(u_nn, (None, 0))

# Compute gradients
u_nn_x = grad(u_nn, argnums=1)
u_nn_xx = grad(u_nn_x, argnums=1)

# Define the residual function
squared_residual = jit(lambda params, x: (u_nn_xx(params, x) + 1.0) ** 2)
squared_residual_batch = vmap(squared_residual, (None, 0))

# Define the physics-informed loss function
@jit
def physics_informed_loss(params, x):
    physics_loss = jnp.mean(squared_residual_batch(params, x))
    return physics_loss

In [None]:
# Define training step with Adam optimizer
def train_step(params, opt_state, x):
    loss, grads = jax.value_and_grad(physics_informed_loss)(params, x)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss

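# Create a learning rate schedule: start at 0.1 and scale by 0.1 after step 2500
# (this boundary is not reached with max_iterations = 2000 below)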
learning_rate_schedule = optax.piecewise_constant_schedule(
    init_value=0.1,
    boundaries_and_scales={2500: 0.1}
)

# Initialize the Adam optimizer with the learning rate schedule
optimizer = optax.adam(learning_rate=learning_rate_schedule)
opt_state = optimizer.init(nn.params)

# Training loop with tqdm progress bar
losses = []
max_iterations = 2000

pbar = trange(max_iterations, desc="Training", leave=True)
for epoch in pbar:
    nn.params, opt_state, current_loss = train_step(nn.params, opt_state, x)
    losses.append(current_loss)
    if epoch % 100 == 0:
        pbar.set_postfix(loss=current_loss)
    if current_loss < 0.0001:
        break

In [None]:
# Create a figure with two subplots
fig, axs = plt.subplots(1, 2, figsize=(10, 2))

# Plot the loss over the number of iterations
axs[0].set_yscale('log')
axs[0].plot(range(len(losses)), losses, label='Loss')
axs[0].set_xlabel('Iteration')
axs[0].set_ylabel('Loss')
axs[0].legend()

# Plot the results
predictions = u_nn_batch(nn.params, x)
axs[1].plot(x, y, label='Exact solution')
axs[1].plot(x, predictions, label='Neural network')
axs[1].set_xlim([0, 1.0])
axs[1].set_ylim([0, 0.5])
axs[1].legend()

# Show the plots
plt.show()