In [1]:
"""
********************************************************************************
training
********************************************************************************
"""
import time
import yaml
import numpy as np 
import tensorflow as tf
import pickle
import numbers


from config_gpu import config_gpu
from pinn_base import PINN
from si_pinn import SI_PINN
from utils import (
    make_logger,
    write_logger,
    eval_dict,
    gen_condition,
    plot_loss_curve,
    plot_comparison,
    pdfs_to_gif,
    plot_comparison1d,
)

from tensorflow.python.framework.ops import disable_eager_execution

filename = "./settings/simple-sir.yaml"


with open(filename, mode="r") as file:
        settings = yaml.safe_load(file)
    # run hyperparameters args
logger_path = make_logger("seed: in model")
args = eval_dict(settings["ARGS"])

In [2]:
# ======model=======
model_args = eval_dict(settings["MODEL"], {"tf": tf, "": np})

for key in model_args.keys():
    if isinstance(model_args[key], list):
        model_args[key] = tf.constant(model_args[key], tf.float32)

var_names = settings["IN_VAR_NAMES"]
func_names = settings["OUT_VAR_NAMES"]
model = SI_PINN(var_names=var_names, func_names=func_names, **(model_args))
model.init_custom_vars(
    dict_consts=settings["CUSTOM_CONSTS"],
    dict_funcs=settings["CUSTOM_FUNCS"],
    var_names=var_names,
)

print("INIT DONE")

in_lb = model_args["in_lb"]
tmin = in_lb[0]
#xmin = in_lb[1]
in_ub = model_args["in_ub"]
tmax = in_ub[0]
# xmax = in_ub[1]
print(model_args)

  super().__init__(**kwargs)


INIT DONE
{'f_hid': 8, 'depth': 4, 'in_lb': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([-1.], dtype=float32)>, 'in_ub': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>, 'act': 'tanh', 'seed': 42, 'lr': 0.0001, 'beta': 0.01, 'dyn_norm': 'inv_dir'}


In [4]:
# ======conditions=======

conds = eval_dict(settings["CONDS"], locals() | {"tf": tf} | model.custom_vars, 1)
conditions = []
for key in list(conds.keys()):
    cond_ = gen_condition(
        conds[key], model_args, func_names=func_names, var_names=var_names, **model.custom_vars
    )
    conditions.append(cond_)
cond_string = [
    "self.loss_(*conditions[" + str(i) + "])," for i in range(len(conditions))
]
cond_string = compile("(" + "".join(cond_string) + ")", '<string>', 'eval')

model.init_dynamical_normalisation(len(conditions))


In [5]:
# ======outputs=======

ns = eval_dict(settings["NS"])
_x = [0] * len(var_names)
for i in range(len(var_names)):
    _x[i] = tf.linspace(in_lb[i], in_ub[i], ns["nx"][i])
_x = (tf.meshgrid(*_x))
print(ns, _x)

#x = tf.cast(np.empty((len(_x), int(np.prod(_x.shape)))), dtype=tf.float32)
x = [0]*len(_x)
for i in range(len(var_names)):
    x[i] = tf.reshape(_x[i],(-1,1)) #tf.cast(_x[i].reshape(-1, 1), dtype=tf.float32)
#x = tf.reshape(_x,()
x_ref = tf.transpose(tf.cast(x, dtype=tf.float32))[0]
u_ref = tf.cast(np.zeros(ns['nx']).reshape(-1, 1), dtype=tf.float32)
exact = tf.cast(model.custom_vars["exact"](x_ref), dtype=tf.float32)

# log
losses_logs = np.empty((len(conds.keys()), 1))

# training
wait = 0
loss_best = tf.constant(1e20)
loss_save = tf.constant(1e20)
t0 = time.perf_counter()

cond_string_here = [
   "model.loss_(*conditions[" + str(i) + "])," for i in range(len(conditions))
]
cond_string_here = "(" + "".join(cond_string_here) + ")"


{'nx': [500]} [<tf.Tensor: shape=(500,), dtype=float32, numpy=
array([-1.        , -0.995992  , -0.99198395, -0.98797596, -0.98396796,
       -0.9799599 , -0.9759519 , -0.9719439 , -0.96793586, -0.96392787,
       -0.9599198 , -0.9559118 , -0.9519038 , -0.94789577, -0.94388777,
       -0.9398798 , -0.9358717 , -0.9318637 , -0.92785573, -0.9238477 ,
       -0.9198397 , -0.9158317 , -0.91182363, -0.90781564, -0.90380764,
       -0.8997996 , -0.8957916 , -0.8917836 , -0.88777554, -0.88376755,
       -0.87975955, -0.8757515 , -0.8717435 , -0.8677355 , -0.86372745,
       -0.85971946, -0.8557114 , -0.8517034 , -0.8476954 , -0.84368736,
       -0.83967936, -0.8356713 , -0.8316633 , -0.8276553 , -0.82364726,
       -0.81963927, -0.8156313 , -0.8116232 , -0.8076152 , -0.8036072 ,
       -0.7995992 , -0.7955912 , -0.7915832 , -0.7875751 , -0.78356713,
       -0.77955914, -0.7755511 , -0.7715431 , -0.7675351 , -0.76352704,
       -0.75951904, -0.75551105, -0.751503  , -0.747495  , -0.743487  ,
 