In [None]:
# Standard library
import os

# Third-party: TensorFlow, NumPy, Matplotlib
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Project-local PINN building blocks
from PINNs_Chron import PINN3D, PINN3DSolver

# Single precision everywhere: DTYPE is reused for every tensor/array created below,
# and Keras is told to use the same default float type.
DTYPE = "float32"
tf.keras.backend.set_floatx(DTYPE)

# Only show TensorFlow log messages at ERROR level or above.
tf.get_logger().setLevel(level='ERROR')

In [None]:
%matplotlib inline

# Figure typography. SIZE is applied to body text, axis titles/labels and tick labels;
# legends use a fixed, slightly smaller size. BIGGER_SIZE is not referenced in this section.
SIZE = 12
BIGGER_SIZE = 16

plt.rc('font', family='Arial', size=SIZE)
plt.rc('axes', titlesize=SIZE, labelsize=SIZE)
plt.rc('xtick', labelsize=SIZE)
plt.rc('ytick', labelsize=SIZE)
plt.rc('legend', fontsize=10)

In [None]:
# Piecewise-constant uplift histories: rate .6 before the switch time, a lower rate after.
# Model time 40 Ma corresponds to 110 Ma geological time.
def uplift(t):
    """NumPy uplift rate at model time(s) ``t``: .6 for t < 40, otherwise .05."""
    return np.where(t < 40, .6, .05)


def tf_uplift(t):
    """TensorFlow counterpart of ``uplift`` (same switch time and rates)."""
    return tf.where(t < 40, .6, .05)


def tf_uplift_inv(t, u1, t1):
    """Inversion form: rate .6 for t < ``t1``, rate ``u1`` afterwards."""
    return tf.where(t < t1, .6, u1)

In [None]:
# Load the topography as rows of (x, y, elevation).
topo_ar = np.loadtxt('Dabie/dabie_utm_km.xyz', dtype=DTYPE)

# Sorted unique coordinates along each axis, plus each row's index into them.
topo_x, topo_x_pos = np.unique(topo_ar[:, 0], return_inverse=True)
topo_y, topo_y_pos = np.unique(topo_ar[:, 1], return_inverse=True)

# Scatter elevations into a dense (n_x, n_y) grid; cells missing from the file stay 0.
topo_pivot = np.zeros((topo_x.size, topo_y.size), dtype=DTYPE)
topo_pivot[topo_x_pos, topo_y_pos] = topo_ar[:, 2]

# Grid extent, used as the interpolation reference corners.
topo_x_min, topo_x_max = np.min(topo_x), np.max(topo_x)
topo_y_min, topo_y_max = np.min(topo_y), np.max(topo_y)

In [None]:
# --- Model parameters (units are not stated here; confirm before reusing) ---
t_end = 150.                      # model end time
h0 = 30.                          # z of sea level above the model base (T_surf default)
xl = topo_x_max - topo_x_min      # x extent of the topography grid
yl = topo_y_max - topo_y_min      # y extent of the topography grid
Tsea = 15.                        # surface temperature at sea level
Tbot = 600.                       # temperature imposed at the model base
kappa = 25.                       # not referenced in this section
air_lapse = 5.                    # surface temperature drop per unit elevation

# --- Number of training points of each kind ---
N_0 = 2000                        # initial-condition points
N_b_b = 2000                      # bottom-boundary points
N_b_s = 5000                      # surface-boundary points
N_b = N_b_b + N_b_s
N_r = 50000                       # interior collocation points

# --- Domain box in (t, x, y, z); x and y start at 0 ---
tmin, tmax = 0., t_end
xmin, xmax = 0., xl
ymin, ymax = 0., yl
# The 6 appears to match the upper amp0 bound used for the DE search.
zmin, zmax = 0., h0 + topo_ar[:, 2].max() * 6.

lb = tf.constant([tmin, xmin, ymin, zmin], dtype=DTYPE)   # lower corner
ub = tf.constant([tmax, xmax, ymax, zmax], dtype=DTYPE)   # upper corner

# Global TF seed for the random sampling below.
tf.random.set_seed(128)

In [None]:
import tensorflow_probability as tfp

def tf_h_fn_inv(t, x, y, amp0, t_decay=70.):
    """Surface height at time ``t`` and map position (x, y).

    Height is ``h0 + amp * topo(x, y)``, where ``topo`` bilinearly interpolates
    ``topo_pivot`` on its regular grid. The amplification ``amp`` is:

    * ``amp0`` for ``t < t_decay``;
    * otherwise linear in t, from ``amp0`` at ``t_decay`` to 1 at ``t_end``.

    Parameters
    ----------
    t : tensor or float
        Model time(s); must broadcast against ``amp0``.
    x, y : tensor-like
        Map coordinates, concatenated along axis 1 (presumably (N, 1) columns).
    amp0 : float or scalar tensor
        Amplification before ``t_decay``; cast to ``DTYPE``.
    t_decay : float or tensor, default 70.
        Time at which the amplification starts relaxing towards 1.

    Returns
    -------
    Tensor of dtype ``DTYPE`` with ``x``'s shape broadcast against ``amp``.
    """
    amp0 = tf.cast(amp0, dtype=DTYPE)
    # Explicit dtype so this still works if DTYPE is switched to float64.
    one = tf.ones(1, dtype=DTYPE)
    # (t_end - t) / (t_end - t_decay) is 1 at t_decay and 0 at t_end.
    relaxed = one + (amp0 - one) * (t_end - t) / (t_end - t_decay)
    amp = tf.where(t < t_decay, amp0, relaxed)

    # NOTE(review): elsewhere x and y are sampled on [0, xl] x [0, yl], but the grid
    # spans [topo_x_min, topo_x_max] x [topo_y_min, topo_y_max]. Unless the topography
    # file is already shifted to start at 0, queries fall outside the grid and get
    # TFP's default fill (edge values extended). Confirm the coordinate convention.
    xy = tf.cast(tf.concat([x, y], axis=1), dtype=DTYPE)
    h_interp = tfp.math.batch_interp_regular_nd_grid(
        xy, [topo_x_min, topo_y_min], [topo_x_max, topo_y_max], topo_pivot, axis=-2)
    return h0 + amp * tf.reshape(h_interp, tf.shape(x))

def tf_h_fn_init(x, y, amp0):
    """Initial (t = 0) surface height at (x, y) for amplification ``amp0``.

    Calls ``tf_h_fn_inv`` at t = 0 with t_decay = 1. Since t < t_decay, the
    amplification is exactly ``amp0``.
    """
    # Explicit dtype keeps these consistent with DTYPE (tf defaults to float32).
    zero = tf.zeros(1, dtype=DTYPE)
    one = tf.ones(1, dtype=DTYPE)
    return tf_h_fn_inv(zero, x, y, amp0, one)

def T_init(x, y, z, amp0):
    """Initial temperature: ``Tbot`` at z = 0, decreasing linearly with height.

    The slope is the initial gradient ``T_grad0(x, y, amp0)``.
    """
    gradient = T_grad0(x, y, amp0)
    return Tbot - gradient * z

# Define boundary condition
def T_surf(surf_z, sealevel_z=h0, lapse=air_lapse):
    """Surface temperature for a surface at height ``surf_z``.

    Starts from ``Tsea`` at ``sealevel_z`` and drops by ``lapse`` per unit of
    elevation above it. Defaults are bound when the function is defined.
    """
    sea_level_T = Tsea * tf.ones(tf.shape(surf_z), dtype=DTYPE)
    return sea_level_T - lapse * (surf_z - sealevel_z)

def T_bot(z):
    """Bottom boundary temperature: a constant ``Tbot`` tensor with ``z``'s static shape.

    Only ``z.shape`` is used; the values of ``z`` are ignored.
    """
    return tf.constant(Tbot, dtype=DTYPE, shape=z.shape)
    
def T_grad_inv(t, x, y, amp0, t_decay=70.):
    """Mean vertical temperature gradient between the base and the surface at time ``t``.

    Returns ``(Tbot - T_surf(h)) / h`` with ``h = tf_h_fn_inv(t, x, y, amp0, t_decay)``.
    The default ``t_decay`` matches ``tf_h_fn_inv``'s own default.
    """
    h = tf_h_fn_inv(t, x, y, amp0, t_decay=t_decay)  # evaluated once, used twice
    return (Tbot - T_surf(h)) / h


def T_grad0(x, y, amp0):
    """Initial (t = 0) mean vertical temperature gradient, ``(Tbot - T_surf(h)) / h``."""
    h_init = tf_h_fn_init(x, y, amp0)
    return (Tbot - T_surf(h_init)) / h_init

In [None]:
def _uniform_split_t(n, t_lo, t_split, t_hi):
    """Draw ``n`` times: half uniform on [t_lo, t_split), the rest on [t_split, t_hi).

    The second interval takes the extra point when ``n`` is odd, so the result
    always has exactly ``n`` rows (shape (n, 1)).
    """
    n_early = n // 2
    t_early = tf.random.uniform((n_early, 1), t_lo, t_split, dtype=DTYPE)
    t_late = tf.random.uniform((n - n_early, 1), t_split, t_hi, dtype=DTYPE)
    return tf.concat([t_early, t_late], axis=0)


def generate_data(amp0, t_decay=70.):
    """Sample initial- and boundary-condition training data for amplification ``amp0``.

    Parameters
    ----------
    amp0 : float or scalar tensor
        Topography amplification passed to the height functions.
    t_decay : float, default 70.
        Forwarded to ``tf_h_fn_inv`` for the surface points only.

    Returns
    -------
    X_data : tensor, shape (N_0 + N_b_b + N_b_s, 4)
        Columns (t, x, y, z); rows are the initial points, then bottom, then surface.
    T_data : tensor
        Target temperature for each row of ``X_data``.
    """
    # Initial condition (t = tmin): z uniform between the base and the initial surface.
    t_0 = tf.ones((N_0, 1), dtype=DTYPE) * lb[0]
    x_0 = tf.random.uniform((N_0, 1), lb[1], ub[1], dtype=DTYPE)
    y_0 = tf.random.uniform((N_0, 1), lb[2], ub[2], dtype=DTYPE)
    z_0 = tf.random.uniform((N_0, 1), lb[3], tf_h_fn_init(x_0, y_0, amp0), dtype=DTYPE)
    X_0 = tf.concat([t_0, x_0, y_0, z_0], axis=1)
    T_0 = T_init(x_0, y_0, z_0, amp0)

    # Split time at 30% of the range. Each boundary set puts half its points on each
    # side, so early times are sampled more densely.
    mid_t = lb[0] + (ub[0] - lb[0]) * .3

    # Bottom boundary (z = zmin).
    t_b = _uniform_split_t(N_b_b, lb[0], mid_t, ub[0])
    x_b = tf.random.uniform((N_b_b, 1), lb[1], ub[1], dtype=DTYPE)
    y_b = tf.random.uniform((N_b_b, 1), lb[2], ub[2], dtype=DTYPE)
    z_b = lb[3] * tf.ones((N_b_b, 1), dtype=DTYPE)

    # Surface boundary: z on the time-dependent surface.
    t_s = _uniform_split_t(N_b_s, lb[0], mid_t, ub[0])
    x_s = tf.random.uniform((N_b_s, 1), lb[1], ub[1], dtype=DTYPE)
    y_s = tf.random.uniform((N_b_s, 1), lb[2], ub[2], dtype=DTYPE)
    z_s = tf.cast(tf_h_fn_inv(t_s, x_s, y_s, amp0, t_decay), DTYPE)

    t_bs = tf.concat([t_b, t_s], axis=0)
    x_bs = tf.concat([x_b, x_s], axis=0)
    y_bs = tf.concat([y_b, y_s], axis=0)
    z_bs = tf.concat([z_b, z_s], axis=0)
    X_bs = tf.concat([t_bs, x_bs, y_bs, z_bs], axis=1)

    # Boundary temperatures: fixed at the base, lapse-rate dependent at the surface.
    T_bs = tf.concat([T_bot(z_b), T_surf(z_s)], axis=0)

    # Stack initial and boundary data (same row order as X_data).
    X_data = tf.concat([X_0, X_bs], axis=0)
    T_data = tf.concat([T_0, T_bs], axis=0)
    return X_data, T_data

In [None]:
class PINNInv3D(PINN3D):
    """``PINN3D`` that carries the inversion parameters as model attributes.

    ``u0``, ``u1``, ``t1`` and ``amp0`` start as placeholder constants. The
    ``transform`` layer reads ``self.amp0`` on every call, so reassigning it
    takes effect without rebuilding the layer.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Inversion parameters (placeholders; u0 is not read in this notebook section).
        self.u0 = tf.constant(.6)
        self.u1 = tf.constant(1.)
        self.t1 = tf.constant(1.)
        self.amp0 = tf.constant(1.)

        def _time_scaled_geotherm(inputs):
            # Input columns are (t, x, y, z); self.Tbot presumably comes from PINN3D.
            t_col = inputs[:, 0:1]
            x_col = inputs[:, 1:2]
            y_col = inputs[:, 2:3]
            z_col = inputs[:, 3:4]
            geotherm = self.Tbot - T_grad_inv(t_col, x_col, y_col, self.amp0) * z_col
            return t_col * geotherm

        self.transform = tf.keras.layers.Lambda(_time_scaled_geotherm)

In [None]:
from scipy.optimize import differential_evolution as diff_evo
class InvSolver(PINN3DSolver):
    """Inverse solver that fits (u1, t1, amp0) to observed ages by differential evolution.

    For each candidate, ``solve_with_DE``:

    1. writes the parameters onto ``self.model``;
    2. regenerates initial/boundary data with ``generate_data``;
    3. trains briefly with ``solve_with_Adam``;
    4. scores the candidate with ``age_loss``.

    The observation attributes (``data_x``, ``data_y``, ``data_age``, ``data_err``,
    ``data_radi``, ``data_method``) must be set before optimizing.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Histories, appended once per DE objective evaluation.
        self.age_loss_hist = []
        self.u1_hist = []
        self.t1_hist = []
        self.amp0_hist = []
        self.it_hist = []
        # Age observations; must be assigned before calling solve_with_DE.
        self.data_age = tf.constant([])
        self.data_method = []
        self.data_radi = []
        self.data_err = tf.constant([])
        self.data_x = []
        self.data_y = []

    def fun_u(self, t):
        """Uplift rate at time(s) ``t`` for the model's current ``u1`` and ``t1``."""
        return tf_uplift_inv(t, self.model.u1, self.model.t1)

    def fun_h(self, t, x0, y0):
        """Surface height at the single point (x0, y0) and time(s) ``t``."""
        return tf_h_fn_inv(t, [[x0]], [[y0]], self.model.amp0)[0]

    def age_loss(self):
        """Error-normalised mean squared age misfit over all samples.

        Returns
        -------
        Scalar tensor ``mean(((pred - data_age) / data_err) ** 2)``.

        Raises
        ------
        ValueError
            If no age samples have been assigned.
        """
        age_pred = []
        for hx, hy, hm, hr in zip(self.data_x, self.data_y, self.data_method, self.data_radi):
            sample_t, sample_T = self.get_tT(hx, hy)
            age_pred.append(self.pred_age(sample_t, sample_T, method=hm, grain_radius=hr))
        if not age_pred:
            raise ValueError('InvSolver.age_loss: no age data assigned (data_x/data_y are empty)')

        # Stack into one tensor shaped like the observations. Subtracting a plain list
        # would let shape-(1,) predictions broadcast against data_age into an (n, n)
        # matrix and silently distort the loss.
        age_pred = tf.cast(tf.stack(age_pred), self.data_age.dtype)
        age_pred = tf.reshape(age_pred, tf.shape(self.data_age))
        return tf.reduce_mean(tf.square((age_pred - self.data_age) / self.data_err))

    def solve_with_DE(self, TF_optimizer, bounds, adam_n=100,
                      echofreq=1000, savefreq=1000,
                      strategy='best2bin', init='sobol', disp=True,
                      mutation=(0.5, 1), recombination=0.7,
                      maxiter=100, popsize=30, polish=False,
                      stop_loss=10., stop_age_loss=.1):
        """Invert for (u1, t1, amp0) with ``scipy.optimize.differential_evolution``.

        Parameters
        ----------
        TF_optimizer : optimizer forwarded to ``solve_with_Adam``.
        bounds : three (min, max) pairs for u1, t1 and amp0, in that order.
        adam_n : Adam iterations run per candidate.
        echofreq, savefreq : forwarded to ``solve_with_Adam``.
        strategy, init, disp, mutation, recombination, maxiter, popsize, polish :
            forwarded to ``differential_evolution``.
        stop_loss, stop_age_loss : DE halts early once ``self.current_loss`` is below
            ``stop_loss`` and the latest age misfit is below ``stop_age_loss``.

        Returns
        -------
        The ``OptimizeResult`` returned by ``differential_evolution``.
        """
        # TODO: t_decay is not part of the inversion; generate_data uses its default.

        def DE_obj_fn(x):
            u1, t1, amp0 = x[0], x[1], x[2]
            self.model.u1 = tf.constant(u1, dtype=DTYPE)
            self.model.t1 = tf.constant(t1, dtype=DTYPE)
            self.model.amp0 = tf.constant(amp0, dtype=DTYPE)
            # Initial/boundary data depend on amp0, so they are regenerated per candidate.
            X, T = generate_data(self.model.amp0)
            self.solve_with_Adam(TF_optimizer, X, T, N=adam_n, echofreq=echofreq, savefreq=savefreq)

            misfit = self.age_loss()
            self.u1_hist.append(u1)
            self.t1_hist.append(t1)
            self.amp0_hist.append(amp0)
            self.it_hist.append(self.iter)
            self.current_age_loss = misfit.numpy()
            self.age_loss_hist.append(self.current_age_loss)
            # SciPy expects a plain scalar objective, not a TF tensor.
            return float(self.current_age_loss)

        def DEcallback(x, convergence=0.):
            print('Best solution: u1 = {:5.4g} t1 = {:6.4g} amp0 = {:5.4g}'.format(
                                  x[0], x[1], x[2]))
            # Read state from this instance, not a module-level global.
            # Returning True halts differential_evolution early.
            return self.current_loss < stop_loss and self.current_age_loss < stop_age_loss

        return diff_evo(DE_obj_fn, bounds, strategy=strategy, init=init,
                        disp=disp, maxiter=maxiter, popsize=popsize, polish=polish,
                        mutation=mutation, recombination=recombination,
                        callback=DEcallback)

In [None]:
# Draw uniformly sampled collocation points
mid_t = lb[0]+(ub[0]-lb[0])*.3
t_r0 = tf.random.uniform((N_r//2, 1), lb[0], mid_t, dtype=DTYPE)
t_r1 = tf.random.uniform((N_r//2, 1), mid_t, ub[0], dtype=DTYPE)
t_r = tf.concat([t_r0, t_r1], axis=0)
# t_r = tf.random.uniform((N_r, 1), lb[0], ub[0], dtype=DTYPE)
x_r = tf.random.uniform((N_r, 1), lb[1], ub[1], dtype=DTYPE)
y_r = tf.random.uniform((N_r, 1), lb[2], ub[2], dtype=DTYPE)
z_r = tf.random.uniform((N_r, 1), lb[3], tf_h_fn_init(x_r, y_r, 1.), dtype=DTYPE)
X_r = tf.concat([t_r, x_r, y_r, z_r], axis=1)

In [None]:
def read_ages(txt, DYTPE='float32'):
    """Read age samples from a whitespace-delimited text file, one sample per row.

    Columns: x, y, age, age error, grain radius, method (method is text).

    Parameters
    ----------
    txt : str or path-like
        File to read.
    DYTPE : str or dtype, default 'float32'
        Float dtype for the numeric columns. The misspelled name is kept so existing
        keyword callers keep working; previously the global DTYPE silently overrode it.

    Returns
    -------
    x, y : np.ndarray
    ages, errs : tf.Tensor
    radi : np.ndarray
    methods : np.ndarray of str
    """
    # Read as text because the method column is non-numeric. ndmin=2 keeps a
    # single-sample file two-dimensional so the column indexing below works.
    ar = np.loadtxt(txt, dtype='str', ndmin=2)
    x = ar[:, 0].astype(DYTPE)
    y = ar[:, 1].astype(DYTPE)
    ages = tf.constant(ar[:, 2].astype(DYTPE))
    errs = tf.constant(ar[:, 3].astype(DYTPE))
    radi = ar[:, 4].astype(DYTPE)
    methods = ar[:, 5]
    return x, y, ages, errs, radi, methods

In [None]:
# Initialize model
model = PINNInv3D(lb, ub, num_hidden_layers=6, num_neurons_per_layer=20)
model.build(input_shape=(None, 4))

lr = 1e-3
# lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay([20000],[1e-2, 1e-3])
optim = tf.keras.optimizers.Adam(learning_rate=lr)

# Initilize PINN solver
solver = InvSolver(model, X_r)
solver.savepath = 'saved_inv_model3d'

solver.data_x, solver.data_y, solver.data_age, solver.data_err, solver.data_radi, solver.data_method = read_ages('dabie_ages.txt')

In [None]:
from time import time
tic = time()  # wall-clock start

# Search bounds for (u1, t1, amp0).
bounds = [(0., .1), (0., 100.), (0., 6.)]

# Adam steps per DE candidate, DE population-size multiplier, DE generations.
adam_n, de_popsize, de_maxit = 50, 10, 200

res = solver.solve_with_DE(
    optim, bounds,
    adam_n=adam_n, popsize=de_popsize, maxiter=de_maxit,
    mutation=.5, recombination=.5,
    echofreq=100, savefreq=100,
)

elapsed = time() - tic
print('\nComputation time: {} seconds'.format(elapsed))

In [None]:
# Training-loss histories recorded by the solver: the total and the parts labelled heat and ibc.
fig, ax = plt.subplots(figsize=(4, 3))
# Raw strings: '\p' in a plain literal is an invalid escape (SyntaxWarning on Python 3.12+).
ax.plot(solver.loss_hist, label=r'$\phi_{total}$')
ax.plot(solver.heat_hist, label=r'$\phi_{heat}$')
ax.plot(solver.bound_hist, label=r'$\phi_{ibc}$')
ax.set_yscale('log')
ax.set_yticks(np.logspace(0, 5, 6))  # one tick per decade, 1 ... 1e5
ax.legend(loc='upper right')
ax.set_xlabel('Iteration')
ax.set_ylabel('Loss')
fig.tight_layout()

SAVE_LOSS_FIG = False  # set True to write the figure (the 'plots' directory must exist)
if SAVE_LOSS_FIG:
    fig.savefig('plots/3dinv_loss.pdf')

In [None]:
# Save the inversion histories and run settings under the solver's run directory.
runpath = solver.savepath + '/' + solver.runname + '/'
os.makedirs(runpath, exist_ok=True)  # np.save does not create missing directories

arrays_to_save = {
    'it_hist': solver.it_hist,
    'loss_hist': solver.loss_hist,
    'age_loss_hist': solver.age_loss_hist,
    'X_r': X_r,
    'u1_hist': solver.u1_hist,
    't1_hist': solver.t1_hist,
    'amp0_hist': solver.amp0_hist,
    # NOTE(review): 'adam_m' looks like a typo for 'adam_n'; kept so existing loaders still work.
    'adam_m': adam_n,
    'de_popsize': de_popsize,
    'de_maxit': de_maxit,
}
for array_name, array_value in arrays_to_save.items():
    np.save(os.path.join(runpath, array_name), array_value)