In [1]:
'''
simplified version of the inverse problem using linear interpolation
solve for (k,p,r) that minimize the difference between the predicted and the true temperature field
Author: Jiachen Guo, Northwestern University
'''
import numpy as onp
import jax
import jax.numpy as np
import time
import os,sys
import optax
from tqdm import trange

currentdir = os.getcwd()
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)

from functools import partial
from src.generate_mesh import *



# Expose a single GPU to this process and turn off XLA's up-front device
# memory preallocation.
GPU_idx = 0
os.environ.update({
    "CUDA_VISIBLE_DEVICES": f"{GPU_idx}",
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
})
# 64-bit mode is left disabled; uncomment the next line to enable it.
# jax.config.update("jax_enable_x64", True)

# Print NumPy arrays without truncation, on wide lines, in fixed-point
# notation with four decimals.
onp.set_printoptions(
    precision=4,
    suppress=True,
    linewidth=1000,
    threshold=sys.maxsize,
)

# Archive whose entries are unpacked further down into U_x, U_y, U_z, U_t, U_k, U_p.
loaded_dict = np.load('./important_data_file/icml_spt.npz')
loaded_sol = dict((name, loaded_dict[name]) for name in loaded_dict)

# Archive whose entries are unpacked further down into U_xf, ..., U_pf
# (used by td_ground_truth to build the reference field).
loaded_dictf = np.load('./important_data_file/icml_spt_inverse_truth.npz')
loaded_solf = dict((name, loaded_dictf[name]) for name in loaded_dictf)
    
# ---- Run configuration ----------------------------------------------------
num_mode = 100
num_calibration_cases = 2  # number of iterations of the calibration loop

# ---- Problem settings -----------------------------------------------------
r = 0.1

num_max_iter = 4

# Element counts handed to the mesh generator: x, y, z, t first,
# then the k, p, e, d axes.
nelem_x = nelem_y = nelem_z = nelem_t = 100
nelem_k = nelem_p = nelem_e = nelem_d = 100
elem_type = 'D1LN2N' # 'D1LN2N'

# Extents handed to the mesh generator for each axis.
Lx = Ly = Lz = 1  # length of the domain
Lt = 0.1          # time span
Lk = 3            # range for conductivity
Lp = 100.
Le = 0.2
Ld = 0.04

# Offsets subtracted from the x and y nodal coordinates after meshing.
shift_x = shift_y = 0
#%%

Gauss_Num_FEM = 2 # 6

# The element tag carries the dimension at index 1 and the nodes-per-element
# count between index 4 and the final character.
dim = int(elem_type[1])
nodes_per_elem = int(elem_type[4:-1]) # same for space and parameter for the same type of elements
elem_dof = dim * nodes_per_elem

# Problem setting

## Mesh generation
non_uniform_mesh_bool = False

# One 1-D mesh per axis. Going by the original note, uniform_mesh_new returns
# the nodal coordinates first and the element-to-node ids second
# (TODO confirm against src.generate_mesh).
x, Elem_nodes_x = uniform_mesh_new(Lx, nelem_x)
y, Elem_nodes_y = uniform_mesh_new(Ly, nelem_y)
z, Elem_nodes_z = uniform_mesh_new(Lz, nelem_z)
t, Elem_nodes_t = uniform_mesh_new(Lt, nelem_t)
k, Elem_nodes_k = uniform_mesh_new(Lk, nelem_k)
p, Elem_nodes_p = uniform_mesh_new(Lp, nelem_p)

e, Elem_nodes_e = uniform_mesh_new(Le, nelem_e)
d, Elem_nodes_d = uniform_mesh_new(Ld, nelem_d)

# Finer k and p meshes over the same extents; these feed `gridsf` below.
nelem_kf = nelem_pf = 1000
kf, Elem_nodes_kf = uniform_mesh_new(Lk, nelem_kf)
pf, Elem_nodes_pf = uniform_mesh_new(Lp, nelem_pf)

# Shift the nodal values (original note: "to avoid void parameters").
# The coarse and fine k / p meshes receive identical offsets.
x, y = x - shift_x, y - shift_y
k, kf = k + 1.0, kf + 1.0
p, pf = p + 100., pf + 100.
e, d = e + 0.10, d + 0.03

# Axis name -> nodal coordinates: the coarse set and the fine (k, p) set.
grids = dict(x=x, y=y, z=z, t=t, k=k, p=p)
gridsf = dict(k=kf, p=pf)
#####XTD
# Each axis has one more global DOF than it has elements.
dof_global_x, dof_global_y, dof_global_z, dof_global_t = (
    nelem_x + 1,
    nelem_y + 1,
    nelem_z + 1,
    nelem_t + 1,
)
dof_global_k, dof_global_p = nelem_k + 1, nelem_p + 1
print(f"\n-----FEM elem_type: {elem_type}, DOFs_x: {dof_global_x}, DOFs_y: {dof_global_y}, DOFs_z: {dof_global_z}, DOFs_t: {dof_global_t}, DOFs_k: {dof_global_k}, DOFs_p: {dof_global_p}----")

# Pull the per-axis arrays out of the loaded archive.
U_x = loaded_sol['x']
U_y = loaded_sol['y']
U_z = loaded_sol['z']
U_t = loaded_sol['t']
U_k = loaded_sol['k']
U_p = loaded_sol['p']

# Keep every 10th index in [0, nelem_t) along the second axis of U_t.
t_idx = np.arange(0, nelem_t, 10)
print(f"t_idx: {t_idx}")
U_t_select = U_t[:, t_idx]



-----FEM elem_type: D1LN2N, DOFs_x: 101, DOFs_y: 101, DOFs_z: 101, DOFs_t: 101, DOFs_k: 101, DOFs_p: 101----
t_idx: [ 0 10 20 30 40 50 60 70 80 90]


## Ground-truth measurement database

In [16]:
# Pull the per-axis arrays out of the ground-truth archive.
U_xf = loaded_solf['x']
U_yf = loaded_solf['y']
U_zf = loaded_solf['z']
U_tf = loaded_solf['t']
U_kf = loaded_solf['k']
U_pf = loaded_solf['p']

# Keep every 100th index in [0, 1000) along the second axis of U_tf.
t_idxf = np.arange(0, 1000, 100)
print(f"t_idx: {t_idxf}")
U_tf_select = U_tf[:, t_idxf]


@jax.jit
def td_ground_truth(x_input, sol_dict, grids):
    """Assemble the ground-truth field at the parameters (k0, p0).

    Each row of the module-level arrays U_kf / U_pf is linearly interpolated
    at k0 / p0 on the supplied grids. The resulting per-row weights scale the
    product of the strided U_xf, U_yf, U_zf rows and the U_tf_select rows,
    and the products are summed over the leading (row) axis.

    Args:
        x_input: dict with entries 'k0' and 'p0', the query parameters.
        sol_dict: accepted for signature compatibility; not read here (the
            arrays come from the module-level U_xf, ..., U_pf, U_tf_select).
        grids: dict whose 'k' and 'p' entries are the nodal coordinates used
            as interpolation abscissae; both are flattened before use.

    Returns:
        Tuple (U, k0, p0): U is the summed field with the time-index axis
        moved to the front; k0 and p0 are the query parameters passed back.
    """
    interval = 40  # stride applied to the second axis of U_xf, U_yf, U_zf

    k0 = x_input['k0']; p0 = x_input['p0']
    grid_k = grids['k'].reshape(-1); grid_p = grids['p'].reshape(-1)

    # Iterate over the rows of the arrays that are actually indexed. The row
    # count used to come from the coarse U_k / U_p, which is only correct when
    # both archives happen to store the same number of rows.
    k_interpolated = np.array([np.interp(k0, grid_k, U_kf[i, :]) for i in range(U_kf.shape[0])])
    p_interpolated = np.array([np.interp(p0, grid_p, U_pf[i, :]) for i in range(U_pf.shape[0])])

    # Broadcast every factor to (row, x, y, z, t) and sum over the row axis.
    U = np.sum(U_xf[:, ::interval, None, None, None] * U_yf[:, None, ::interval, None, None] * U_zf[:, None, None, ::interval, None] * \
              U_tf_select[:, None, None, None, :] * k_interpolated[:, None, None, None, None] * p_interpolated[:, None, None, None, None], axis = 0)

    # (x, y, z, t) -> (t, x, y, z)
    U = U.transpose(3, 0, 1, 2)
    # Under jax.jit this print executes while the function is being traced.
    print(U.shape)
    return U, k0, p0

t_idx: [  0 100 200 300 400 500 600 700 800 900]


## Define the loss function

In [17]:
@jax.jit
def forward(x_input, sol_dict, grids):
    """Evaluate the predicted field at the parameters (k0, p0).

    Every row of the module-level U_k / U_p is linearly interpolated at
    k0 / p0. The per-row weights multiply the strided U_x, U_y, U_z rows and
    the U_t_select rows, and the products are summed over the row axis.

    Args:
        x_input: dict with entries 'k0' and 'p0'.
        sol_dict: accepted but not read; the arrays are taken from the
            module-level U_x, ..., U_p, U_t_select.
        grids: dict whose 'k' and 'p' entries (flattened here) are the
            interpolation abscissae.

    Returns:
        The summed field with the time-index axis moved to the front.
    """
    stride = 4  # step along the second axis of U_x, U_y, U_z

    k_query, p_query = x_input['k0'], x_input['p0']
    k_nodes = grids['k'].reshape(-1)
    p_nodes = grids['p'].reshape(-1)

    # One interpolated weight per row of U_k / U_p.
    k_weights = np.array([np.interp(k_query, k_nodes, U_k[m, :]) for m in range(U_k.shape[0])])
    p_weights = np.array([np.interp(p_query, p_nodes, U_p[m, :]) for m in range(U_p.shape[0])])

    # Lift every factor to a common (row, x, y, z, t) layout.
    x_part = U_x[:, ::stride, None, None, None]
    y_part = U_y[:, None, ::stride, None, None]
    z_part = U_z[:, None, None, ::stride, None]
    t_part = U_t_select[:, None, None, None, :]
    k_part = k_weights[:, None, None, None, None]
    p_part = p_weights[:, None, None, None, None]

    field = np.sum(x_part * y_part * z_part * t_part * k_part * p_part, axis=0)

    # (x, y, z, t) -> (t, x, y, z)
    field = field.transpose(3, 0, 1, 2)
    # Under jax.jit this print executes while the function is being traced.
    print(field.shape)
    return field



 

@jax.jit
def loss_fun(x_input, U_true):
    """Return the norm (over all entries) of forward(x_input, ...) minus U_true.

    The result is a single scalar, not one value per mode.
    """
    residual = forward(x_input, loaded_sol, grids) - U_true
    return np.linalg.norm(residual)

@jax.jit
def relative_L2_norm_error(x_input, U_true):
    """Return norm(forward(x_input, ...) - U_true) divided by norm(U_true).

    Both norms run over all entries, so the result is a single scalar.
    """
    mismatch = np.linalg.norm(forward(x_input, loaded_sol, grids) - U_true)
    reference = np.linalg.norm(U_true)
    return mismatch / reference

@jax.jit
def grad_fun(x_input, U_true):
    """Return the gradient of loss_fun with respect to x_input.

    The result has the same structure as x_input.
    """
    d_loss = jax.grad(loss_fun, argnums=0)
    return d_loss(x_input, U_true)


def inv_opt(params, optimizer, target=None, num_steps=None, tol=1e-5):
    """Fit (k0, p0) to a target field by projected gradient-based optimisation.

    Each step evaluates loss_fun and its gradient at the current parameters,
    applies the optimizer update, and then projects the parameters back into
    the box spanned by the minimum and maximum of grids['k'] / grids['p'].

    Args:
        params: dict with entries 'k0' and 'p0', the initial guess.
        optimizer: optax gradient transformation (provides init / update).
        target: field to match. When omitted, the module-level U_true is
            used, which is what the two-argument call relied on implicitly.
        num_steps: maximum number of steps. When omitted, the module-level
            NUM_STEPS is used.
        tol: the loop stops early once the relative error falls below this.

    Returns:
        Tuple (params, loss_value, error). loss_value and error are evaluated
        at the parameters *entering* the last executed step, i.e. before that
        step's update and projection, so they lag the returned params by one
        update.

    Raises:
        ValueError: if the number of steps is smaller than 1 (previously this
            surfaced as a NameError on the final print).
    """
    # Backward-compatible fallbacks to the notebook-level globals.
    if target is None:
        target = U_true
    if num_steps is None:
        num_steps = NUM_STEPS
    if num_steps < 1:
        raise ValueError(f"num_steps must be at least 1, got {num_steps}")

    opt_state = optimizer.init(params)

    # Box constraints: keep each parameter inside the span of its grid.
    lower_bounds = {
        'k0': np.min(grids['k']),
        'p0': np.min(grids['p']),
    }
    upper_bounds = {
        'k0': np.max(grids['k']),
        'p0': np.max(grids['p']),
    }

    @jax.jit
    def step(params, opt_state, target):
        # A single pass yields both the loss and its gradient; the relative
        # error is the same residual norm scaled by the target norm.
        loss_val, grads = jax.value_and_grad(loss_fun)(params, target)
        error = loss_val / np.linalg.norm(target)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        params = optax.projections.projection_box(params, lower_bounds, upper_bounds)
        return params, opt_state, loss_val, error

    for epoch in range(num_steps):
        params, opt_state, loss_value, error = step(params, opt_state, target)
        if error < tol:
            break
    print(f'step {epoch}, loss: {loss_value}, error: {error}')
    return params, loss_value, error


## Running calibration with randomly selected input parameters
### Hyperparameters for the optimizer: LEARNING_RATE, NUM_STEPS

In [19]:
# Optimizer hyperparameters. They do not change between cases, so they are
# set once here. NUM_STEPS stays a module-level name because inv_opt reads it.
LEARNING_RATE = 1e-1
NUM_STEPS = 3000
optimizer = optax.adam(LEARNING_RATE)

# Seeded generator so the sampled ground truths and initial guesses, and
# therefore the reported statistics, are reproducible across runs.
CALIBRATION_SEED = 0
calibration_rng = onp.random.default_rng(CALIBRATION_SEED)

loss_list = []
error_L2_list = []
error_param_p_list = []
error_param_k_list = []
for i in trange(num_calibration_cases):

    # Draw the parameters to recover uniformly between the first and last
    # k / p nodes, then build the field that plays the role of the measurement.
    x_ground_truth = {'k0': calibration_rng.uniform(k[0], k[-1])[0], 'p0': calibration_rng.uniform(p[0], p[-1])[0]}
    U_true, k_true, p_true = td_ground_truth(x_ground_truth, loaded_solf, gridsf)
    print(f"U_true: {U_true.shape}")

    # Independent random starting point drawn from the same ranges.
    x_initial = {'k0': calibration_rng.uniform(k[0], k[-1])[0], 'p0': calibration_rng.uniform(p[0], p[-1])[0]}

    x_opt, loss_value, error = inv_opt(x_initial, optimizer)

    loss_list.append(loss_value)
    error_L2_list.append(error)
    # Relative error of each recovered parameter against its true value.
    error_param_p_list.append(np.abs(p_true - x_opt['p0'])/p_true)
    error_param_k_list.append(np.abs(k_true - x_opt['k0'])/k_true)

    print(f"Final x_input is {x_opt}")
    print(f'k_true: {k_true}, p_true: {p_true}')

# Mean and standard deviation over all calibration cases.
print(f"loss_mean: {np.mean(np.array(loss_list))}, loss_std: {np.std(np.array(loss_list))}")
print(f"error_L2_mean: {np.mean(np.array(error_L2_list))}, error_std: {np.std(np.array(error_L2_list))}")
print(f"error_param_k_mean: {np.mean(np.array(error_param_k_list))}, error_std: {np.std(np.array(error_param_k_list))}")
print(f"error_param_p_mean: {np.mean(np.array(error_param_p_list))}, error_std: {np.std(np.array(error_param_p_list))}")


  0%|          | 0/2 [00:00<?, ?it/s]

(10, 26, 26, 26)
U_true: (10, 26, 26, 26)
(10, 26, 26, 26)


 50%|█████     | 1/2 [00:23<00:23, 23.67s/it]

step 2999, loss: 33.59607696533203, error: 0.001964666647836566
Final x_input is {'k0': Array(3.2099, dtype=float32), 'p0': Array(119.6181, dtype=float32)}
k_true: 3.2208926677703857, p_true: 119.90022277832031
U_true: (10, 26, 26, 26)


100%|██████████| 2/2 [00:41<00:00, 20.60s/it]

step 2999, loss: 55.398216247558594, error: 0.0017087507294490933
Final x_input is {'k0': Array(2.5473, dtype=float32), 'p0': Array(194.8276, dtype=float32)}
k_true: 2.545858860015869, p_true: 194.4252471923828
loss_mean: 44.49714660644531, loss_std: 10.901069641113281
error_L2_mean: 0.0018367087468504906, error_std: 0.00012795795919373631
error_param_k_mean: 0.0019813396502286196, error_std: 0.0014316167216748
error_param_p_mean: 0.002210971899330616, error_std: 0.00014172808732837439



