In [None]:
# Visualize the results from an ode1_pinn.py run.

In [None]:
from importlib import import_module
import os
import numpy as np
import matplotlib.pyplot as plt
import sys

In [None]:
# Specify the run ID (aka problem name).
runid = "linear"

# Put the subdirectory for the problem at the FRONT of the module search path
# so its modules (the problem definition and its hyperparameters) win over
# any same-named modules elsewhere on the path. Any existing copies of the
# entry are removed first, so re-running this cell does not keep growing
# sys.path, and the entry moves back to the front if another run was loaded
# in between.
run_path = os.path.join(".", runid)
sys.path[:] = [entry for entry in sys.path if entry != run_path]
sys.path.insert(0, run_path)

# Drop any cached copies of the run modules. Without this, re-running the
# cell after changing runid would silently reuse the modules imported for the
# previous run (imports are cached in sys.modules).
for module_name in (runid, "hyperparameters"):
    sys.modules.pop(module_name, None)

# Import the problem definition.
p = import_module(runid)

# Read the run hyperparameters (hyperparameters.py in the run directory).
hp = import_module("hyperparameters")

In [None]:
# Read the training and validation x-coordinates saved in the run directory.
x_train, x_val = (
    np.loadtxt(os.path.join(runid, data_file))
    for data_file in ("x_train.dat", "x_val.dat")
)

In [None]:
# Build one-decimal tick labels for hp.nx_train evenly spaced points on [0, 1].
# NOTE(review): x_labels is not used by any cell visible here; confirm whether
# it is still needed.
xx = np.linspace(0.0, 1.0, num=hp.nx_train)
x_labels = list(map("%.1f".__mod__, xx))

In [None]:
# Read the loss history written by the run (plotted against epoch below).
losses_file = os.path.join(runid, "losses.dat")
losses = np.loadtxt(losses_file)

In [None]:
# Plot the loss function histories on a logarithmic y-axis.
# The title uses matplotlib mathtext. The backslash in "\\eta" is escaped so
# the source has no invalid escape sequence ("\e" warns, and is a
# SyntaxWarning on Python 3.12+). "\n" is a real newline that splits the
# title onto two lines.
fig, ax = plt.subplots()
ax.semilogy(losses, label="L (total)")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss function")
ax.legend()
ax.grid()
ax.set_title("Loss function evolution for %s\n$\\eta$=%s, H=%s, $n_x$=%s" %
             (runid, hp.learning_rate, hp.H, hp.nx_train))
plt.show()

In [None]:
# Read the network's y-values at the training and validation points.
y_train, y_val = (
    np.loadtxt(os.path.join(runid, data_file))
    for data_file in ("y_train.dat", "y_val.dat")
)


In [None]:
# Plot the trained y-values against the training points.
fig, ax = plt.subplots()
ax.plot(x_train, y_train)
ax.set(title="Trained y(x)", xlabel="x", ylabel="y")
ax.grid()
plt.show()

In [None]:
# If an analytical solution is available, plot the error in the trained
# y-values and report the root-mean-square error (RMSE).
# getattr() also accepts a problem module that does not define
# analytical_solution at all, treating that the same as None.
analytical_solution = getattr(p, "analytical_solution", None)
if analytical_solution is not None:
    y_analytical = analytical_solution(x_train)
    y_error = y_train - y_analytical
    # np.mean averages over every element, so the RMSE stays correct even if
    # the arrays are multi-dimensional (sum/len would divide by rows only).
    rmse = np.sqrt(np.mean(y_error**2))
    fig, ax = plt.subplots()
    ax.plot(x_train, y_error)
    ax.set_title("Error in trained y(x)")
    ax.set_xlabel("x")
    ax.set_ylabel("$y_t - y_a$")
    ax.grid()
    plt.show()
    print("RMSE = %s" % rmse)

In [None]:
# Plot the network's y-values at the validation points.
fig, ax = plt.subplots()
ax.plot(x_val, y_val)
ax.set(title="Validation y(x)", xlabel="x", ylabel="y")
ax.grid()
plt.show()

In [None]:
# If an analytical solution is available, plot the error in the validation
# y-values and report the root-mean-square error (RMSE).
# getattr() also accepts a problem module that does not define
# analytical_solution at all, treating that the same as None.
analytical_solution = getattr(p, "analytical_solution", None)
if analytical_solution is not None:
    y_analytical = analytical_solution(x_val)
    y_error = y_val - y_analytical
    # np.mean averages over every element, so the RMSE stays correct even if
    # the arrays are multi-dimensional (sum/len would divide by rows only).
    rmse = np.sqrt(np.mean(y_error**2))
    fig, ax = plt.subplots()
    ax.plot(x_val, y_error)
    ax.set_title("Error in validation y(x)")
    ax.set_xlabel("x")
    ax.set_ylabel("$y_v - y_a$")
    ax.grid()
    plt.show()
    print("RMSE = %s" % rmse)