In [None]:
from modules._models_v2 import create_dense_model, WavePinn
from modules.plots import plot_training_loss_linlog, plot_wave_model
from modules.data import simulate_wave
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Problem setup for the 1-D wave equation u_tt = c^2 * u_xx.
# Reference notes (separation of variables):
# https://personal.math.ubc.ca/~feldman/m267/separation.pdf
c = 1.0           # wave speed; also passed to WavePinn
length = 1.0      # NOTE(review): not used by the cells shown; f_u / f_u_init hard-code L = 1 via sin(n*pi*x)
n_samples = 2000  # first argument to simulate_wave (number of sampled points)

# Seed the RNGs before simulate_wave and model construction so that
# "Restart Kernel -> Run All" reproduces the same samples and initial weights.
# (Some TF ops on GPU can still be nondeterministic.)
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

def f_u(tx):
    """Exact solution u(t, x) of the wave equation, used as the ground-truth target.

    Superposition of the n = 5 and n = 7 standing-wave modes (amplitudes 1 and 2):

        u(t, x) = sin(5*pi*x) * cos(5*c*pi*t) + 2 * sin(7*pi*x) * cos(7*c*pi*t)

    It satisfies u_tt = c^2 * u_xx and vanishes at x = 0 and x = 1.

    Args:
        tx: 2-D tensor/array; column 0 holds t and column 1 holds x.

    Returns:
        Column tensor with one row per row of `tx`.
    """
    t, x = tx[:, 0:1], tx[:, 1:2]
    mode_5 = tf.sin(5 * np.pi * x) * tf.cos(5 * c * np.pi * t)
    mode_7 = 2*tf.sin(7 * np.pi * x) * tf.cos(7 * c * np.pi * t)
    return mode_5 + mode_7

def f_u_init(tx):
    """Initial displacement u(0, x) = sin(5*pi*x) + 2*sin(7*pi*x) (equals f_u at t = 0).

    Args:
        tx: 2-D tensor/array; only column 1 (x) is read.

    Returns:
        Column tensor with one row per row of `tx`.
    """
    x_col = tx[:, 1:2]
    first_mode = tf.sin(5 * np.pi * x_col)
    second_mode = 2*tf.sin(7 * np.pi * x_col)
    return first_mode + second_mode

def f_du_dt(tx):
    """Initial velocity du/dt at t = 0: zero everywhere (the cos(.t) terms of f_u have zero slope at t = 0).

    Args:
        tx: 2-D tensor/array; column 0 (t) is used only for its shape/dtype.

    Returns:
        Zero column tensor with one row per row of `tx`.
    """
    t_col = tx[:, 0:1]
    return tf.zeros_like(t_col)

def f_u_bnd(tx):
    """Boundary displacement: homogeneous Dirichlet condition u = 0.

    Args:
        tx: 2-D tensor/array; column 1 (x) is used only for its shape/dtype.

    Returns:
        Zero column tensor with one row per row of `tx`.
    """
    x_col = tx[:, 1:2]
    return tf.zeros_like(x_col)

In [None]:
# simulate_wave returns three groups (meaning inferred from the variable names):
# sample points + residual targets, initial-condition points + u / du/dt targets,
# and boundary points + u targets.
(
    (tx_samples, residual),
    (tx_init, u_init, du_dt_init),
    (tx_bndry, u_bndry),
) = simulate_wave(n_samples, f_u_init, f_du_dt, f_u_bnd)

In [None]:
# Three input groups, in the order returned by simulate_wave.
inputs = [
    tx_samples,
    tx_init,
    tx_bndry,
]
# Targets: exact solution at the sample points, followed by the targets
# produced by simulate_wave. Order presumably must match WavePinn's loss terms.
outputs = [
    f_u(tx_samples),
    residual,
    u_init,
    du_dt_init,
    u_bndry,
]

In [None]:
# Dense backbone with 2 inputs (t, x) and 1 output (u); presumably three hidden
# layers of 128 units. It is wrapped in the wave-equation PINN with speed c.
backbone = create_dense_model([128] * 3, 'elu', 'he_normal', n_inputs=2, n_outputs=1)
pinn = WavePinn(backbone, c)

# Learning rate follows 1e-3 * 0.93 ** (step / 500) over optimizer steps.
scheduler = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-3,
    decay_steps=500,
    decay_rate=0.93,
)
optimizer = tf.keras.optimizers.Adam(learning_rate=scheduler)
pinn.compile(optimizer=optimizer)

In [None]:
# Long training run with per-epoch logging disabled; the returned History
# object keeps the loss curves for plotting.
history = pinn.fit(
    inputs,
    outputs,
    epochs=10000,
    batch_size=200,
    verbose=0,
)

In [None]:
# Plot the loss history recorded by `pinn.fit` using the project helper from modules.plots.
plot_training_loss_linlog(history.history)

In [None]:
# Visualize the trained backbone with the project helper from modules.plots.
# NOTE(review): the positional arguments (0, 1.0, 1) are defined by plot_wave_model's
# signature, which is not visible here — consider passing them by keyword.
plot_wave_model(pinn.backbone, 0, 1.0, 1)

In [None]:
def plot_wave_at_x(model, x, time, save_path = None, n_points: int = 100) -> None:
    """
    Plot the model's predicted displacement u(t, x) over time at a fixed position x.

    Args:
        model (tf.keras.Model): Model mapping (t, x) pairs to u; it is called via
            ``model.predict`` on an array whose column 0 is t and column 1 is x.
        x (float): Fixed spatial coordinate at which the solution is sampled.
        time (float): End of the time window; t is sampled uniformly on [0, time].
        save_path (str, optional): If given, the figure is also saved to this path. Defaults to None.
        n_points (int, optional): Number of time samples. Defaults to 100.

    Raises:
        ValueError: If ``n_points`` is smaller than 2.
    """
    if n_points < 2:
        raise ValueError(f"n_points must be at least 2, got {n_points}")

    t = np.linspace(0, time, n_points)
    # Column order (t, x) matches the tx convention used by f_u and the training data.
    tx = np.stack([t, np.full(t.shape, x)], axis=-1)
    # verbose=0 keeps Keras' progress bar out of the notebook output.
    u = model.predict(tx, batch_size=1000, verbose=0)

    # Explicit figure/axes so we never draw onto a stale "current" figure.
    fig, ax = plt.subplots()
    ax.plot(t, u)
    ax.set_xlabel('t')
    ax.set_ylabel('u')
    ax.set_title(f'Predicted u(t, x={x})')
    if save_path:
        fig.savefig(save_path)
    plt.show()

In [None]:
# Predicted displacement at the midpoint of the string over the first half time unit.
plot_wave_at_x(pinn.backbone, x=0.5, time=0.5)