In [None]:
# Load the TensorBoard notebook extension; it provides the `%tensorboard` magic used further down.
%load_ext tensorboard

In [None]:
import os
import datetime

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from assets.ml.DeepONetwork import TrunkNN
from assets.ml.DeepONetwork import BranchNN
from assets.ml.DeepONetwork import DeepONET
from assets.ml.DeepONetwork import DeepOPINN

In [None]:
# Each run logs to its own timestamped directory under ./assets/logs/fits, so runs
# started at different seconds do not overwrite each other's TensorBoard events.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = os.path.join(".", "assets", "logs", "fits", run_stamp)

# histogram_freq=1: record weight histograms after every epoch.
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
)

In [None]:
# Seed Python, NumPy and TensorFlow before any layer is created so that weight
# initialisation (and Keras' per-epoch shuffling) is repeatable on "Restart & Run All".
# GPU kernels may still introduce some nondeterminism. Requires TF >= 2.7.
tf.keras.utils.set_random_seed(42)

HIDDEN_UNITS = 20      # width of every hidden Dense layer
N_HIDDEN_LAYERS = 5    # number of hidden layers in each sub-network
N_SENSORS = 100        # length of the sampled input function fed to the branch net;
                       # NOTE: must match the size of the `domain` grid defined below


def make_hidden_layers(prefix, units=HIDDEN_UNITS, depth=N_HIDDEN_LAYERS):
    """Return `depth` tanh Dense layers named f"{prefix}Dense_layer{k}" for k = 1..depth."""
    return [
        tf.keras.layers.Dense(units, activation="tanh", name=f"{prefix}Dense_layer{k}")
        for k in range(1, depth + 1)
    ]


branchHiddenLayers = make_hidden_layers("branchNET")
trunkHiddenLayers = make_hidden_layers("trunkNET")

branchNET = BranchNN(hiddenLayers=branchHiddenLayers, input_shape=(N_SENSORS,))
trunkNET = TrunkNN(hiddenLayers=trunkHiddenLayers, input_shape=(1,))  # one query coordinate x per sample
deepONET = DeepONET(branchNN=branchNET, trunkNN=trunkNET)

In [None]:
# 100 evenly spaced points on [0, 1], both endpoints included. The same grid serves as
# the branch net's sensor locations and as the trunk net's query points.
domain = np.linspace(0.0, 1.0, num=100)

In [None]:
def f(x):
    """Input function sin(2*pi*x), applied elementwise to a scalar or NumPy array."""
    angle = 2 * np.pi * x
    return np.sin(angle)


def g(x):
    """Target function cos(2*pi*x) (equal to f'(x) / (2*pi)), applied elementwise like `f`."""
    angle = 2 * np.pi * x
    return np.cos(angle)

In [None]:
# Sample the input function f and the target g on the shared grid (used by the plot below).
u, v = f(domain), g(domain)

In [None]:
# Exploratory figure: f and g against x on the bottom axis, plus the parametric curve
# (f(x), g(x)) on a twin x-axis along the top. Both axes share the y-axis.
fig, ax = plt.subplots(1, 1, figsize=(10, 9))

ax.plot(domain, u, label=r"$f(x) = \sin(2\pi x)$")
# g = cos(2*pi*x) is f'(x) divided by 2*pi, not f'(x) itself -- label it precisely.
ax.plot(domain, v, label=r"$g(x) = \cos(2\pi x) = f^{\prime}(x) / 2\pi$")
ax.set_xlabel("$x$")
# The y-axis is shared by all three curves, including the parametric one (its y-data is g).
ax.set_ylabel("$f(x)$, $g(x)$")
ax.tick_params(axis="x", labelcolor="tab:blue")
ax.grid(True, color="lightgrey")

ax2 = ax.twiny()
# Parametric curve: the x-data is u = f(x) and the y-data is v = g(x).
# No ylabel is set on ax2: twiny() hides the twin's own y-axis.
ax2.plot(u, v, color="C3", label="$g(x)$ vs $f(x)$")
ax2.set_xlabel("$f(x)$")
ax2.grid(True, color="darkgrey")
ax2.tick_params(axis="x", labelcolor="tab:red")

# Twin axes each keep their own legend, and the two would overlap; merge them into one.
handles, labels = ax.get_legend_handles_labels()
twin_handles, twin_labels = ax2.get_legend_handles_labels()
ax2.legend(handles + twin_handles, labels + twin_labels)

fig.suptitle("Input $f$, target $g$, and the parametric curve $(f(x), g(x))$")
plt.show()

In [None]:
# One training sample per grid point x_i:
#   branch input -> the whole sampled input function f(domain); identical in every row,
#                   since a single input function is used
#   trunk input  -> the query location x_i
#   target       -> g(x_i)
# n is tied to the grid so the three arrays always align row-wise.
n = domain.size

branch_input = np.tile(f(domain), (n, 1))   # (n, domain.size)
print(branch_input.shape)

trunk_input = domain.reshape(-1, 1)          # (n, 1)
print(trunk_input.shape)

y = g(domain).reshape(-1, 1)                 # (n, 1)
print(y.shape)

assert branch_input.shape[0] == trunk_input.shape[0] == y.shape[0] == n

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3, beta_1=0.9, beta_2=0.999)
deepONET.compile(optimizer=optimizer, loss='mse', metrics=['mae'])

In [None]:
# Build with feature widths taken from the training arrays (number of sensors for the
# branch net, number of query coordinates for the trunk net) instead of repeating the
# literal 100, so that changing the grid size cannot desynchronise them.
deepONET.build(input_shape=[(None, branch_input.shape[1]), (None, trunk_input.shape[1])])

In [None]:
# Show the model's layer-by-layer architecture (the model was built in the previous cell).
deepONET.summary()

In [None]:
# Keras takes the `validation_split` fraction from the *end* of the arrays, before any
# shuffling. The samples are ordered by x, so without this permutation the model would
# never see x >= ~0.7 during training and "validation" would measure extrapolation.
# The permutation only indexes copies; branch_input/trunk_input/y keep grid order for
# the prediction cells below.
split_rng = np.random.default_rng(42)
order = split_rng.permutation(len(y))

history = deepONET.fit(
    x=[branch_input[order], trunk_input[order]],
    y=y[order],
    epochs=1500,
    batch_size=16,
    validation_split=0.3,
    callbacks=[tensorboard_callback],
    verbose=True,
)

In [None]:
# Launch TensorBoard inline on port 8080. ./assets/logs is the parent of the
# fits/<timestamp> directories written by `tensorboard_callback`, so every run is listed.
%tensorboard --logdir ./assets/logs --port=8080

In [None]:
# Evaluate the trained operator on every (input function, query point) pair, in grid order.
model_inputs = [branch_input, trunk_input]
preds = deepONET.predict(model_inputs)

In [None]:
# Summarise the prediction instead of dumping every row.
print(preds.shape)
print(preds[:5])
# Worst-case absolute error against the analytic target. preds follows the order of
# `domain` because predict() was called on the grid-ordered inputs.
print("max |pred - g|:", np.max(np.abs(np.ravel(preds) - g(domain))))

In [None]:
# Compare the DeepONet prediction with the analytic target g over the full grid.
fig, ax = plt.subplots(1, 1, figsize=(10, 9))

ax.plot(domain, g(domain), label=r"Ground truth $g(x) = \cos(2\pi x)$")
# Dashed so the prediction stays visible where it overlaps the ground truth.
ax.plot(domain, preds, "--", label="Predicted")
ax.set_xlabel("$x$")
ax.set_ylabel("$g(x)$")
ax.set_title("DeepONet prediction vs ground truth")
ax.grid(True, color="lightgrey")
ax.legend()

plt.show()