In [None]:
# Setup cell: environment variables, import path and imports.
# Statement order matters: the environment variables and the sys.path entry
# below are set before the imports that presumably depend on them.
import sys
import os
# PyTorch CUDA allocator option; set before `import torch` further down, since
# the allocator presumably reads it when CUDA is first initialised.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# realpath() resolves this relative name against the current working directory
# (following symlinks, and without requiring the file to exist), so dir_path is
# normally the CWD.
file_name = 'TEST_ORIGINAL.mat'
dir_path = os.path.dirname(os.path.realpath(file_name))
# Make modules in dir_path importable; presumably this is where the local
# `utils`, `pinn` and `generate_plots` modules imported below live.
sys.path.append(dir_path)
from sklearn.model_selection import train_test_split
import argparse
import numpy as np
# Select DeepXDE's PyTorch backend; set before `import deepxde` below.
os.environ["DDE_BACKEND"] = "pytorch"
import deepxde as dde # version 0.11 or higher
import utils
import pinn
# NOTE(review): plot_variable, plot_loss, plot_multiple_phies and plt are not
# used in the visible cells — confirm before removing.
from generate_plots import plot_variable, plot_loss, plot_phie, plot_multiple_phies
import matplotlib.pyplot as plt
import torch
# Release cached, unused CUDA memory held by this process (helps when cells
# are re-run in a long-lived kernel).
torch.cuda.empty_cache()
print("CUDA Available:", torch.cuda.is_available())

def build_parser():
    """Build the command-line parser for the PINN training script.

    Returns:
        argparse.ArgumentParser with the data, model, inverse-problem and
        plotting options used by ``main``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', '--file-name', dest='file_name', required=True, type=str,
                        help='File name for input data')
    # Default to '' (not None) so the output path can be built by plain string
    # concatenation (dir_path + model_folder_name) when the flag is omitted.
    parser.add_argument('-m', '--model-folder-name', dest='model_folder_name', required=False, type=str,
                        default='', help='Folder name to save model (prefix /)')
    parser.add_argument('-d', '--dimension', dest='dim', required=True, type=int,
                        help='Model dimension. Needs to match the input data')
    parser.add_argument('-n', '--noise', dest='noise', action='store_true',
                        help='Add noise to the data')
    parser.add_argument('-w', '--w-input', dest='w_input', action='store_true',
                        help='Add W to the model input data')
    parser.add_argument('-v', '--inverse', dest='inverse', required=False, type=str,
                        help='Solve the inverse problem, specify variables to predict (e.g. a / ad / abd)')
    parser.add_argument('-ht', '--heter', dest='heter', required=False, action='store_true',
                        help='Predict heterogeneity - only in 2D')
    parser.add_argument('-p', '--plot', dest='plot', required=False, action='store_true',
                        help='Create and save plots')
    parser.add_argument('-a', '--animation', dest='animation', required=False, action='store_true',
                        help='Create and save 2D Animation')
    return parser


if __name__ == "__main__":
    # NOTE(review): inside a Jupyter kernel, sys.argv typically holds the
    # kernel's own arguments (e.g. `-f <connection-file>`). Parsing it will
    # then fail on the required -d flag or pick up the wrong -f value. When
    # running interactively, pass an explicit list instead, e.g.
    # build_parser().parse_args(['-f', 'data.mat', '-d', '2']).
    args = build_parser().parse_args()

# Other parameters
noise_factor = 1  # multiplier on the standard-normal noise added to phie_train when --noise is set
test_size = 0.75  # fraction of samples held out as the test split in train_test_split

In [None]:
def _rmse(pred, true):
    """Return the root-mean-square error between ``pred`` and ``true``.

    Singleton axes are squeezed out of both arrays first. This stops an (n, 1)
    prediction compared against an (n,) target from being silently broadcast
    into an (n, n) difference matrix.

    Raises:
        ValueError: if the squeezed shapes differ.
    """
    pred = np.squeeze(np.asarray(pred))
    true = np.squeeze(np.asarray(true))
    if pred.shape != true.shape:
        raise ValueError(f"rMSE shape mismatch: prediction {pred.shape} vs target {true.shape}")
    return np.sqrt(np.mean(np.square(pred - true)))


def main(args):
    """Load the data, train the PINN, and report test errors.

    Reads the module-level settings ``test_size``, ``noise_factor`` and
    ``dir_path`` defined in the first cell. Writes 'losshistory.png' and
    'true_vs_pred.png' to the current directory.

    Args:
        args: Parsed command-line namespace. Uses ``inverse``, ``file_name``,
            ``dim``, ``noise``, ``w_input``, ``heter`` and ``model_folder_name``.

    Returns:
        The trained model object returned by ``model_pinn.train``.
    """
    # System dynamics. The inverse-problem spec comes from the CLI; the bare
    # name `inverse` used previously is undefined (NameError).
    dynamics = utils.system_dynamics()
    params = dynamics.params_to_inverse(args.inverse)
    print(params)

    # Load data. Use the file and dimension supplied on the command line
    # rather than the hard-coded global file name and an undefined `dim`.
    observe_x, V, W, phie, observe_elec = dynamics.generate_data(args.file_name, args.dim)

    # Random train/test splits of the V/W sample points and of the electrode phie samples
    observe_train, observe_test, v_train, v_test, w_train, w_test = train_test_split(
        observe_x, V, W, test_size=test_size
    )
    elec_train, elec_test, phie_train, phie_test = train_test_split(
        observe_elec, phie, test_size=test_size
    )

    # Optionally corrupt the training phie with standard-normal noise scaled by noise_factor
    if args.noise:
        phie_train += noise_factor * np.random.randn(*phie_train.shape)

    geomtime = dynamics.geometry_time(args.dim)
    bc = dynamics.BC_func(args.dim, geomtime)
    ic = dynamics.IC_func(observe_train, v_train)
    ic2 = dynamics.IC_func(elec_train, phie_train)

    # Observed phie at the training electrodes constrains output component 2
    observe_phie = dde.PointSetBC(elec_train, phie_train, component=2)
    input_data = [bc, ic, ic2, observe_phie]
    if args.w_input:
        # Optionally also constrain output component 1 with the W training data
        input_data.append(dde.PointSetBC(observe_train, w_train, component=1))

    # Define model
    torch.cuda.empty_cache()
    model_pinn = pinn.PINN(dynamics, args.dim, args.heter, args.inverse)
    model_pinn.define_pinn(geomtime, input_data, observe_train)

    # Train model. model_folder_name may be None when the optional flag is
    # omitted, so fall back to '' instead of failing on str + None.
    out_path = dir_path + (args.model_folder_name or "")
    # `device` is only reported here; it is not passed to the model in this function.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Training on device:", device)
    model, losshistory, train_state = model_pinn.train(out_path, params)
    dde.utils.external.plot_loss_history(losshistory, 'losshistory.png')
    print('Final loss weights:', model.loss_weights)

    # Predict V (output component 0) at the held-out points
    v_pred = model.predict(observe_test)[:, 0:1]

    # Predict phie (output component 2) at every electrode point and rearrange it for comparison
    phie_pred = model.predict(observe_elec)[:, 2:3]
    phie_pred = dynamics.arrange_phie(observe_elec, phie_pred)

    true_phie = phie.reshape(dynamics.nelec, int(dynamics.max_t))
    plot_phie(dynamics, phie_pred, true_phie, 'true_vs_pred.png')

    # Compute rMSE (shape-checked so mismatched layouts fail loudly instead of broadcasting)
    rmse_v = _rmse(v_pred, v_test)
    rmse_phie = _rmse(phie_pred, true_phie)
    print("---------------------------------")
    print("V rMSE:", rmse_v)
    print("Phie rMSE:", rmse_phie)
    print("---------------------------------")
    if params:
        print("Estimated parameters:", params)
        print("---------------------------------")

    return model

# Run the full pipeline with the arguments parsed in the first cell; `model`
# receives whatever main() returns.
# NOTE(review): `args` only exists if the argparse block ran (i.e. under
# __name__ == "__main__"); otherwise this line raises NameError.
model = main(args)