JAX implementation of "A Sequential Meta-Transfer (SMT) Learning to Combat Complexities of Physics-Informed Neural Networks: Application to Composites Autoclave Processing"

Paper: https://arxiv.org/abs/2308.06447

In [1]:
# Import libraries
import jax
import numpy as onp
import jax.numpy as np
from jax import random, grad, vmap, jit

import optax
import flax.linen as nn

import math
import itertools
from tqdm import trange
from scipy.stats import qmc
import matplotlib.pyplot as plt

In [None]:
# System specification (Refer to the paper for more details: https://arxiv.org/abs/2308.06447)
# Temperatures are scaled by T_scaler = 1/T_max and time by t_scaler = 1/t_end,
# so both are roughly in [0, 1] inside the network and the PDE.
T_max = 200 + 273 # Max temperature (K), used only to build the scaler
U_min = 0.001 # Initial degree of cure (doc)
U_diff = 1 - U_min # Width of the doc output range (U_min, 1) used by neural_net_a
T_scaler = 1/T_max # Temperature scaler
T_min = (20 + 273) * T_scaler # Min temperature (scaled; 20 C expressed in K)
t_end = 18800 # Cure cycle duration (s)
t_raw = np.linspace(0, t_end, 5000)
t_scaler = 1/t_raw.max() # Time scaler (maps [0, t_end] s onto [0, 1])
t = t_raw * t_scaler

# 2-hold Cure cycle specification (scaled)
# Rates are given in K/min and converted to scaled-temperature per scaled-time;
# every *_end time below is in scaled time, not seconds.
T_ini = (20 + 273) * T_scaler # Initial temperature (scaled K)
ramp_rate_1 = (2/60) * (T_scaler/t_scaler) # Heat rate 1 (2 K/min, scaled)
T_hold_1 = (110 + 273) * T_scaler # Hold 1 temperature (scaled K)
t_ramp_1_end = (T_hold_1 - T_ini)/ramp_rate_1 # When heat ramp 1 ends (scaled time)
t_hold_1_end = 60*60*t_scaler + t_ramp_1_end # When hold 1 (60 min) ends (scaled time)
ramp_rate_2 = (2/60) * (T_scaler/t_scaler) # Heat rate 2 (2 K/min, scaled)
T_hold_2 = (180 + 273) * T_scaler # Hold 2 temperature (scaled K)
t_ramp_2_end = (T_hold_2 - T_hold_1)/ramp_rate_2 + t_hold_1_end # When heat ramp 2 ends (scaled time)
t_hold_2_end = 120*60*t_scaler + t_ramp_2_end # When hold 2 (120 min) ends (scaled time)
T_end = (20 + 273) * T_scaler # Final temperature after cool-down (scaled K)
ramp_rate_3 = (3/60) * (T_scaler/t_scaler) # Cool-down rate (3 K/min, scaled)
t_ramp_3_end = (T_hold_2 - T_end)/ramp_rate_3 + t_hold_2_end # When cool-down ramp 3 ends (scaled time)

# Air temperature function
def Temp_air(t):
    """Scaled autoclave air temperature for the two-hold cure cycle.

    Piecewise-linear profile: ramp 1 -> hold 1 -> ramp 2 -> hold 2 -> cool-down
    ramp 3 -> final hold at ``T_end``. The final hold stops the cool-down ramp
    from extrapolating below ``T_end`` for t >= ``t_ramp_3_end``.

    Args:
        t: Scaled time (scalar or array; ``t / t_scaler`` is seconds).

    Returns:
        Scaled air temperature (divide by ``T_scaler`` for Kelvin), broadcast to
        the shape of ``t``.
    """
    # Nested np.where keeps the function vectorized and traceable under jit/vmap.
    return np.where(t < t_ramp_1_end, T_ini + t*ramp_rate_1,  # Ramp 1
           np.where(t < t_hold_1_end, T_hold_1,  # Hold 1
           np.where(t < t_ramp_2_end, T_hold_1 + (t - t_hold_1_end)*ramp_rate_2,  # Ramp 2
           np.where(t < t_hold_2_end, T_hold_2,  # Hold 2
           np.where(t < t_ramp_3_end, T_hold_2 - (t - t_hold_2_end)*ramp_rate_3,  # Ramp 3 (cool-down)
                    T_end)))))  # Final hold: never cool below T_end


# Air temperature over the scaled time grid (scaled units), plotted as a sanity check.
T_boundary = Temp_air(t)
plt.plot(T_boundary)
# Mean air temperature over the cycle, converted back from scaled units.
T_mean = (T_boundary/T_scaler).mean() # K (not C): later used as T_ave - 273 to obtain Celsius

# Composite part specifications
# Adapted from:
# 1) Andrew Johnston's thesis: https://open.library.ubc.ca/soa/cIRcle/collections/ubctheses/831/items/1.0088805
# 2) Niaki et al. https://www.sciencedirect.com/science/article/pii/S0045782521002966
# 3) Raven software: https://www.convergent.ca/products/raven-simulation-software

part_len = 0.03 #  part length (m)
part_nodes = 50 # Number of grid points along x
x_raw = np.linspace(0, part_len, part_nodes)
x_scaler = 1/x_raw.max() # Scaler (maps [0, part_len] m onto [0, 1])
x = x_raw*x_scaler

# Interior bounds of the grid (first and last grid points excluded).
# NOTE(review): x_lb/x_ub/t_lb/t_ub/lb/ub are not referenced by the other cells
# shown here; data_generator recomputes its own bounds locally.
x_lb, x_ub = x[1], x[-2]
t_lb, t_ub = t[1], t[-2]
lb = np.array([x_lb, t_lb])
ub = np.array([x_ub, t_ub])

X, T = np.meshgrid(x, t) # Both have shape (len(t), len(x)): rows index time, columns index x

X = X.flatten()[:, None]
T = T.flatten()[:, None]

# Full space-time grid, shape (len(t)*len(x), 2), columns [x, t] (scaled).
# data_generator filters it by time window.
F = np.asarray(np.hstack([X, T]))

# Model parameters and normalization
# Temperature/cure-dependent properties are evaluated once at (T_ave, alpha_ave)
# and then treated as constants in the PDE.
T_ave = T_mean # Average temperature for properties estimation (K)
alpha_ave = np.array(0.5) #  Average doc for properties estimation

# Fibre properties
rho_f = np.array(1.790e03) # fibre density (kg/m3), RAVEN Model
k_f = 2.4  + 1.560e-2 * (T_ave-273) # Fibre thermal conductivity (W/(m K)), RAVEN Model, transverse direction
Cp_f = 750 + 2.05 *  (T_ave-273-20)  # Fibre specific heat capacity (J/ (kg K)), RAVEN Model

# Resin properties
rho_r = np.array(1.300e3) # Resin density (kg/m3), RAVEN Model
k_r = 0.148 + 3.430E-04 * (T_ave-273) + 6.070E-02 * alpha_ave # Resin thermal conductivity (W/(m K)), RAVEN Model
Cp_r = 1005 + 3.74 *  (T_ave-273-20)  # resin specific heat capacity (J/ (kg K)), RAVEN Model
H_r = np.array(5.4e5) # Resin heat of reaction per unit mass (J / kg), RAVEN Model
nu_r = np.array(0.426) # Resin volume fraction in composite material (1-0.574)
# NOTE(review): h_c is not referenced in the cells shown; the boundary
# coefficients come from C_tasks instead.
h_c = np.array(120.0) # Heat transfer coefficient (W/(m2 K))

# Cure kinetics properties, 8552 epoxy resin
A = np.array(1.528e5)   # Pre-exponential cure rate coefficient (1/s)
dE = np.array(6.650e4)     #  Activation energy (J/mol)
M = np.array(0.8129) # First exponential constant
N = np.array(2.7360) # Second exponential constant
C = np.array(43.09) # Diffusion constant
ALCT = np.array(5.475e-3) #  Constant accounting for increase in critical resin degree of cure with temperature (1/K)
ALC = np.array(-1.6840) # Critical degree of cure at T = 0 K
R = np.array(8.314)     # Gas constant (J/(mol K))

# Composite part properties (fibre+resin)
nu_f = 1. - nu_r # fiber volume fraction
rho_c = rho_r * nu_r + rho_f * nu_f # Density (kg/m3), volume-fraction rule of mixtures
# NOTE(review): Cp is mixed by volume fraction; a per-mass specific heat is
# usually mixed by mass fraction (rho_i*nu_i/rho_c) — confirm against the reference.
Cp_c = Cp_r * nu_r + Cp_f * nu_f # Specific heat capacity (J/ (kg K))
BB = 2 * (k_r/k_f - 1) #eq (B.70) Andrew Johnston thesis
CC = (nu_f/math.pi)**0.5  #eq (B.70) Andrew Johnston thesis
DD = (1-(BB**2)*(CC**2))**0.5  #eq (B.70) Andrew Johnston thesis
k_c = k_r *( (1-2*CC) + 1 / BB * (math.pi-4/DD*math.atan(DD/(1+BB*CC))) )# Thermal conductivity (W/(m K)) eq (B.70) Andrew Johnston thesis
a_c =  k_c/(rho_c*Cp_c) # a in heat transfer PDE (thermal diffusivity, m2 / s)
b =  rho_r*H_r*nu_r/(rho_c*Cp_c) # b in heat transfer PDE (K)
# NOTE(review): empirical factor carried over from the original implementation;
# its origin is not documented here — TODO confirm against the paper.
b = b/1.8255 # Correction

# Normalized properties (x_hat = x*x_scaler, t_hat = t*t_scaler, T_hat = T*T_scaler)
a_c_norm = a_c * ((x_scaler)**2)/t_scaler
b_norm = b * T_scaler
A_norm = A/t_scaler
# dE and R are rescaled by the same x/t factors, so dE_norm/(R_norm*T_hat)
# equals dE/(R*T) with T in Kelvin: the Arrhenius exponent is unchanged.
dE_norm = dE * x_scaler**2/t_scaler**2
R_norm = R * (x_scaler**2)/(t_scaler**2) /T_scaler
ALCT_norm = ALCT /T_scaler # so that ALCT_norm*T_hat == ALCT*T (Kelvin)

In [35]:
# Define neural network class
class Net(nn.Module):
    """Fully connected PINN body: 4 tanh hidden layers of 64 units each.

    Attributes:
        features: Number of output neurons. The notebook uses 2 (raw
            temperature and raw degree-of-cure), which neural_net_T /
            neural_net_a post-process.
    """
    features: int

    @nn.compact
    def __call__(self, x):
        # Layer names are kept as 'trainable1'..'trainable5' so parameter trees
        # produced by earlier versions still match.
        for layer in range(1, 5):
            x = nn.Dense(features=64, name=f'trainable{layer}')(x)
            x = nn.tanh(x)
        # Output width follows the declared `features` field (previously the
        # field was ignored and the width was hard-coded to 2).
        x = nn.Dense(features=self.features, name='trainable5')(x)
        return x

In [36]:
# Define necessary functions to construct PINNs loss function

# Neural network with 2 output neurons (T and doc)
def neural_net(params, x, t):
    """Raw (un-activated) first output neuron of the network at one (x, t) point."""
    raw_outputs = model.apply(params, np.hstack([x, t]))
    return raw_outputs[0]

# Function outputting corrected T neuron value
def neural_net_T(params, x, t):
    """Temperature prediction (scaled) at one (x, t) point.

    The first output neuron is passed through softplus (non-negative) and
    shifted by T_min, so the prediction never drops below T_min.
    """
    raw_outputs = model.apply(params, np.hstack([x, t]))
    raw_temperature = raw_outputs[0]
    return jax.nn.softplus(raw_temperature) + T_min

# Function outputting corrected doc neuron value
def neural_net_a(params, x, t):
    """Degree-of-cure prediction at one (x, t) point.

    The second output neuron is squashed into (0, 1) by a sigmoid and then
    mapped affinely onto (U_min, 1).
    """
    raw_outputs = model.apply(params, np.hstack([x, t]))
    raw_cure = raw_outputs[1]
    return U_diff*jax.nn.sigmoid(raw_cure) + U_min

# Cure kinetics (Andrew Johnston thesis)
def cure_kinetics(params, x, t):
    """Predicted cure rate (scaled d(alpha)/dt) at one (x, t) point.

    Arrhenius factor divided by a diffusion factor, multiplied by the
    autocatalytic term alpha**M * (1 - alpha)**N, all evaluated on the
    network's own T and alpha predictions.
    """
    temp = neural_net_T(params, x, t)
    doc = neural_net_a(params, x, t)
    arrhenius = A_norm * np.exp(-dE_norm/(R_norm*temp))
    diffusion = 1. + np.exp(C*(doc - (ALC + ALCT_norm*temp)))
    return (arrhenius/diffusion)*(doc**M)*((1-doc)**N)

# ODE loss
def ode_net(params, x, t):
    """Cure-ODE residual d(alpha)/dt - cure_rate at one (x, t) point (scaled units)."""
    da_dt = grad(neural_net_a, argnums=2)(params, x, t)
    return da_dt - cure_kinetics(params, x, t)

# PDE loss calculator
def pde_net(params, x, t):
    """Heat-equation residual T_t - a*T_xx - b*cure_rate at one (x, t) point (scaled units)."""
    dT_dt = grad(neural_net_T, argnums=2)(params, x, t)
    d2T_dx2 = grad(grad(neural_net_T, argnums=1), argnums=1)(params, x, t)
    heat_source = b_norm*cure_kinetics(params, x, t)
    return dT_dt - a_c_norm*d2T_dx2 - heat_source

# Bottom BC loss calculator
def bcb_net(params, C_bot, x, t):
    """Robin boundary residual at the bottom surface (x = x.min()).

    Residual: (T_air - T) + C_bot * dT/dx. C_bot presumably plays the role of a
    scaled conductivity / heat-transfer-coefficient ratio (assumption — confirm).
    """
    dT_dx = grad(neural_net_T, argnums=1)(params, x, t)
    mismatch = Temp_air(t) - neural_net_T(params, x, t)
    return mismatch + C_bot*dT_dx

# Top BC loss calculator
def bct_net(params, C_top, x, t):
    """Robin boundary residual at the top surface (x = x.max()).

    Residual: (T - T_air) + C_top * dT/dx. The sign is flipped relative to
    bcb_net, presumably because the outward normal points in +x here.
    """
    dT_dx = grad(neural_net_T, argnums=1)(params, x, t)
    mismatch = neural_net_T(params, x, t) - Temp_air(t)
    return mismatch + C_top*dT_dx

# vmap the functions
# Per-point functions vectorized over a batch of points: params (and, for the
# BC residuals, the coefficient C_bot / C_top) are shared (None), while x and t
# are mapped along axis 0.
# T prediction and doc prediction:
u_pred_fn = vmap(neural_net_T, (None, 0, 0))
u_pred_fn_a = vmap(neural_net_a, (None, 0, 0))
# Heat-PDE and cure-ODE residuals:
p_pred_fn = vmap(pde_net, (None, 0, 0))
o_pred_fn = vmap(ode_net, (None, 0, 0))
# Cure rate:
cure_fn = vmap(cure_kinetics, (None, 0, 0))
# Bottom / top Robin boundary residuals:
bcb_fn = vmap(bcb_net, (None, None, 0, 0))
bct_fn = vmap(bct_net, (None, None, 0, 0))

# Evaluate the network and the residual over the grid
@jit
def loss_pde(params, F_pde_intv):
    """Mean squared heat-equation residual over collocation points.

    F_pde_intv: array whose column 0 holds x and column 1 holds t.
    """
    xs, ts = F_pde_intv[:,0], F_pde_intv[:,1]
    residuals = p_pred_fn(params, xs, ts)
    return np.mean(residuals**2)

@jit
def loss_ode(params, F_pde_intv):
    """Mean squared cure-ODE residual over collocation points (columns: x, t)."""
    xs, ts = F_pde_intv[:,0], F_pde_intv[:,1]
    residuals = o_pred_fn(params, xs, ts)
    return np.mean(residuals**2)

@jit
def loss_ics_T(params, F_ic_intv, U_T_ic_intv):
    """MSE between predicted and target temperature at the initial-condition points.

    Both sides are flattened, so the targets may be shaped (n,) or (n, 1).
    """
    predicted = u_pred_fn(params, F_ic_intv[:,0], F_ic_intv[:,1])
    mismatch = U_T_ic_intv.flatten() - predicted.flatten()
    return np.mean(mismatch**2)

@jit
def loss_ics_a(params, F_ic_intv, U_a_ic_intv):
    """MSE between predicted and target degree of cure at the initial-condition points.

    Both sides are flattened, so the targets may be shaped (n,) or (n, 1).
    """
    predicted = u_pred_fn_a(params, F_ic_intv[:,0], F_ic_intv[:,1])
    mismatch = U_a_ic_intv.flatten() - predicted.flatten()
    return np.mean(mismatch**2)

@jit
def loss_bcb(params, C_bot, F_bcb_intv):
    """Mean squared bottom-boundary (Robin) residual over F_bcb_intv (columns: x, t)."""
    xs, ts = F_bcb_intv[:,0], F_bcb_intv[:,1]
    residuals = bcb_fn(params, C_bot, xs, ts)
    return np.mean(residuals**2)

@jit
def loss_bct(params, C_top, F_bct_intv):
    """Mean squared top-boundary (Robin) residual over F_bct_intv (columns: x, t)."""
    xs, ts = F_bct_intv[:,0], F_bct_intv[:,1]
    residuals = bct_fn(params, C_top, xs, ts)
    return np.mean(residuals**2)

# PINNs loss function
@jit
def loss(params, C_bot, C_top, F_pde_intv, F_ic_intv, F_bcb_intv, F_bct_intv, U_T_ic_intv, U_a_ic_intv, w_ic_T, w_ic_a, w_bcb, w_bct, w_pde):
    """Weighted PINN loss for a single task.

    Sum of the initial-condition (T and alpha), bottom/top boundary, cure-ODE
    and heat-PDE mean-squared terms. The ODE and PDE terms share ``w_pde``.
    """
    ic_T = loss_ics_T(params, F_ic_intv, U_T_ic_intv)
    ic_a = loss_ics_a(params, F_ic_intv, U_a_ic_intv)
    bc_bottom = loss_bcb(params, C_bot, F_bcb_intv)
    bc_top = loss_bct(params, C_top, F_bct_intv)
    pde_term = loss_pde(params, F_pde_intv)
    ode_term = loss_ode(params, F_pde_intv)
    # Terms are added left to right in a fixed order (IC, BC, ODE, PDE).
    total = w_ic_T*ic_T + w_ic_a*ic_a + w_bcb*bc_bottom + w_bct*bc_top + w_pde*ode_term + w_pde*pde_term
    return total

# Vmap the loss components
# Task-batched variants of the loss components (axis 0 = task):
# the IC losses batch over per-task targets, the BC losses over per-task
# coefficients, and `loss` over (C_bot, C_top, U_T_ic_intv, U_a_ic_intv).
# NOTE(review): none of these are referenced in the cells shown here; the MAML
# code batches over tasks through step_inner_vmap instead.
loss_ics_T_vmap = vmap(loss_ics_T, (None, None, 0))
loss_ics_a_vmap = vmap(loss_ics_a, (None, None, 0))
loss_bcb_vmap = vmap(loss_bcb, (None, 0, None))
loss_bct_vmap = vmap(loss_bct, (None, 0, None))
loss_vmap = vmap(loss, (None, 0, 0, None, None, None, None, 0, 0, None, None, None, None, None))

In [37]:
# Meta-learning (based on MAML formulation: https://arxiv.org/abs/1703.03400)

@jit
def update_params(params, grads, learning_rate = 0.00001):
  """One plain gradient-descent step, ``params - learning_rate * grads``, per leaf.

  Used as the MAML inner-loop update.

  Args:
    params: Parameter pytree (e.g. the flax variables dict).
    grads: Pytree with the same structure as ``params``.
    learning_rate: Inner-loop step size.

  Returns:
    Updated parameter pytree with the same structure as ``params``.
  """
  # jax.tree_map is deprecated and removed in recent JAX releases;
  # jax.tree_util.tree_map is the stable spelling with identical semantics.
  return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)

# Inner loop optimization (see MAML paper)
@jit
def step_inner(params, C_bot, C_top, F_pde_intv, F_ic_intv, F_bcb_intv, F_bct_intv, U_T_ic_intv, U_a_ic_intv, w_ic_T, w_ic_a, w_bcb, w_bct, w_pde,
               F_pde_test, F_ic_test, F_bcb_test, F_bct_test, U_T_ic_test, U_a_ic_test):
    """MAML inner loop for a single task.

    Takes one update_params step using the gradient of the task loss on the
    "intv" (support) points. Returns the loss of the adapted parameters on the
    "test" (query) points. The outer loop differentiates this value w.r.t. the
    original ``params``.
    """
    weights = (w_ic_T, w_ic_a, w_bcb, w_bct, w_pde)
    support_grads = grad(loss)(params, C_bot, C_top, F_pde_intv, F_ic_intv, F_bcb_intv, F_bct_intv,
                               U_T_ic_intv, U_a_ic_intv, *weights)
    adapted = update_params(params, support_grads)
    return loss(adapted, C_bot, C_top, F_pde_test, F_ic_test, F_bcb_test, F_bct_test,
                U_T_ic_test, U_a_ic_test, *weights)

# Vmap the inner loop over tasks: axis 0 maps over C_bot, C_top and the per-task
# IC targets (U_T_ic_intv, U_a_ic_intv, U_T_ic_test, U_a_ic_test). Params, point
# sets and loss weights are shared by all tasks (None).
step_inner_vmap = vmap(step_inner, (None, 0, 0, None, None, None, None, 0, 0, None, None, None, None, None, None, None, None, None, 0, 0))

# Calculate average loss among support set tasks
@jit
def step_inner_loss_mean(params, C_bot, C_top, F_pde_intv, F_ic_intv, F_bcb_intv, F_bct_intv, U_T_ic_intv, U_a_ic_intv, w_ic_T, w_ic_a, w_bcb, w_bct, w_pde,
               F_pde_test, F_ic_test, F_bcb_test, F_bct_test, U_T_ic_test, U_a_ic_test):
    """MAML meta-objective: post-adaptation query loss averaged over all support-set tasks."""
    per_task_losses = step_inner_vmap(params, C_bot, C_top, F_pde_intv, F_ic_intv, F_bcb_intv, F_bct_intv,
                                      U_T_ic_intv, U_a_ic_intv, w_ic_T, w_ic_a, w_bcb, w_bct, w_pde,
                                      F_pde_test, F_ic_test, F_bcb_test, F_bct_test, U_T_ic_test, U_a_ic_test)
    return per_task_losses.mean()

# Outer loop optimization (see MAML paper)
@jit
def step_outter(params, C_bot, C_top, F_pde_intv, F_ic_intv, F_bcb_intv, F_bct_intv, U_T_ic_intv, U_a_ic_intv, w_ic_T, w_ic_a, w_bcb, w_bct, w_pde,
                opt_state_maml, F_pde_test, F_ic_test, F_bcb_test, F_bct_test, U_T_ic_test, U_a_ic_test):
    """One MAML outer (meta) update of ``params`` using the global ``optimizer_maml``.

    Returns (updated params, updated optimizer state).

    NOTE(review): optimizer_maml is read as a global when this jitted function
    is traced. Re-defining optimizer_maml later does not trigger a retrace when
    argument shapes are unchanged. This is harmless only while every interval
    builds an identically configured optimizer.
    """
    meta_grads = grad(step_inner_loss_mean)(params, C_bot, C_top, F_pde_intv, F_ic_intv, F_bcb_intv, F_bct_intv,
                                            U_T_ic_intv, U_a_ic_intv, w_ic_T, w_ic_a, w_bcb, w_bct, w_pde,
                                            F_pde_test, F_ic_test, F_bcb_test, F_bct_test, U_T_ic_test, U_a_ic_test)
    updates, new_opt_state = optimizer_maml.update(meta_grads, opt_state_maml, params)
    return optax.apply_updates(params, updates), new_opt_state

In [38]:
# Function to generate boundary, collocation and test points
def data_generator(t_b, t_e, n_bc, n_f, seed=None):
  """Sample boundary and collocation points for the scaled time window [t_b, t_e].

  Args:
    t_b, t_e: Start / end of the window (scaled time).
    n_bc: Number of Latin-hypercube time samples shared by both boundaries.
    n_f: Number of Latin-hypercube interior collocation points.
    seed: Optional seed (int or numpy Generator) for the Latin-hypercube
      engines. ``None`` (default) keeps the previous unseeded behaviour. Pass a
      value to make the sampled points reproducible.

  Returns:
    F_bcb: (n_bc, 2) points [x.min(), t] on the bottom boundary.
    F_bct: (n_bc, 2) points [x.max(), t] on the top boundary (same times as F_bcb).
    F_pde: (n_f, 2) collocation points [x, t], x within the interior grid
      bounds x[1]..x[-2].
    F_intv: rows of the global grid F whose time lies strictly inside (t_b, t_e).
  """
  # One Generator drives both engines, so a single seed fixes every sample
  # while the two engines still draw from different parts of the stream.
  sampler_rng = onp.random.default_rng(seed)

  engine_bc = qmc.LatinHypercube(d=1, seed=sampler_rng)
  sample_bc = engine_bc.random(n=n_bc)
  t_interval = t_b + (t_e-t_b)*sample_bc
  F_bct = np.hstack([np.ones(t_interval.shape)*(x.max()), t_interval])
  F_bcb = np.hstack([np.ones(t_interval.shape)*(x.min()), t_interval])

  # Interior x bounds (first/last grid points excluded), full window in t.
  lb = np.array([x[1], t_b])
  ub = np.array([x[-2], t_e])
  engine = qmc.LatinHypercube(d=2, seed=sampler_rng)
  sample = engine.random(n=n_f)
  F_pde = lb + (ub-lb)*sample

  F_intv = F[np.logical_and(F[:, 1]> t_b, F[:, 1]< t_e)]

  return F_bcb, F_bct, F_pde, F_intv

In [39]:
# Prepare training/test points for the time interval n for training the meta-learner
def training_data(time_b, delta, n_ic, n_bc, n_f, n_ic_test, n_bc_test, n_f_test):
  """Build training ("intv") and test point sets for the window [time_b, time_b + delta].

  Initial-condition points are n_ic (resp. n_ic_test) evenly spaced x in [0, 1]
  at t = time_b. Boundary and collocation points come from data_generator.

  Returns:
    (F_ic_intv, F_bcb_intv, F_bct_intv, F_pde_intv, F_intv,
     F_ic_test, F_bcb_test, F_bct_test, F_pde_test, F_test)
  """
  time_e = time_b + delta

  def ic_points(n_points):
    # Evenly spaced x across the part, all at the window start time.
    xs = np.linspace(0, 1, n_points)
    ts = np.ones(xs.shape)*time_b
    return np.hstack([xs.reshape([-1,1]), ts.reshape([-1,1])])

  F_bcb_intv, F_bct_intv, F_pde_intv, F_intv = data_generator(time_b, time_e, n_bc, n_f)
  F_ic_intv = ic_points(n_ic)

  F_bcb_test, F_bct_test, F_pde_test, F_test = data_generator(time_b, time_e, n_bc_test, n_f_test)
  F_ic_test = ic_points(n_ic_test)

  return (F_ic_intv, F_bcb_intv, F_bct_intv, F_pde_intv, F_intv,
          F_ic_test, F_bcb_test, F_bct_test, F_pde_test, F_test)

In [40]:
# Generate initial conditions for meta tasks
def meta_inits(n_task, time_b, T_min, U_min, F_ic_intv, F_ic_test, params_meta):
  """Initial-condition targets for every support-set task at the window start.

  For the first window (time_b == 0.0) every task starts from the uniform state
  T = T_min, alpha = U_min. Otherwise the targets are the predictions, at the
  IC points, of each task's fine-tuned network from the previous window
  (params_meta[task]).

  Returns:
    Four arrays stacked over tasks (leading axis n_task): T and alpha targets
    on F_ic_intv, then T and alpha targets on F_ic_test. Each task's entry is
    (n_points, 1) for the first window and (n_points,) afterwards. The IC
    losses flatten both forms.
  """
  def targets(task_idx, pts):
    # Returns (T target, alpha target) for one task on one point set.
    if time_b == 0.0:
      return (np.ones([pts.shape[0], 1])*T_min,
              np.ones([pts.shape[0], 1])*U_min)
    task_params = params_meta[task_idx]
    return (u_pred_fn(task_params, pts[:,0], pts[:,1]),
            u_pred_fn_a(task_params, pts[:,0], pts[:,1]))

  T_intv, a_intv, T_test, a_test = [], [], [], []
  for task_idx in range(n_task):
    T_tr, a_tr = targets(task_idx, F_ic_intv)
    T_te, a_te = targets(task_idx, F_ic_test)
    T_intv.append(T_tr)
    a_intv.append(a_tr)
    T_test.append(T_te)
    a_test.append(a_te)

  return np.array(T_intv), np.array(a_intv), np.array(T_test), np.array(a_test)

In [41]:
# Meta learner training
def meta_train(nIter, params, opt_state_maml, shuffle = False, ignore_pde = False, n_warmup = 10000):
  """Train the meta-learner (MAML outer loop) on the current time interval.

  Reads the interval's data and tasks from notebook globals (C_tasks,
  F_*_intv / F_*_test, U_*_meta_ar, w_*).

  Args:
    nIter: Number of outer-loop steps with the full loss.
    params: Meta-learner parameters to start from.
    opt_state_maml: optimizer_maml state matching ``params``.
    shuffle: Currently unused (point reshuffling is disabled, see note below).
    ignore_pde: If True, first run ``n_warmup`` outer steps with the PDE/ODE
      weight set to 0 (IC/BC-only warm-up).
    n_warmup: Number of warm-up steps when ``ignore_pde`` is True.

  Returns:
    The trained meta-learner parameters.
  """
  # Ignore PDE/ODE loss terms for warming up (only for the first time interval)
  # See paper's remarks: https://arxiv.org/abs/2308.06447
  if ignore_pde:
    # The warm-up progress bar is created only when warm-up actually runs.
    for _ in trange(n_warmup):
      # PDE/ODE weight set to 0, so only the IC and BC terms drive the update.
      params, opt_state_maml = step_outter(params, C_tasks[:,0],  C_tasks[:,1], F_pde_intv, F_ic_intv, F_bcb_intv, F_bct_intv,
                                    U_T_ic_intv_meta_ar, U_a_ic_intv_meta_ar, w_ic_T, w_ic_a, w_bcb, w_bct, 0.0,
                                    opt_state_maml, F_pde_test, F_ic_test, F_bcb_test, F_bct_test, U_T_ic_test_meta_ar, U_a_ic_test_meta_ar)

  # Every progress term is reported for support-set task 0 only, so the BC
  # terms use task 0's coefficients (passing the whole C_tasks column would
  # broadcast into an average over all tasks).
  C_bot_0, C_top_0 = C_tasks[0,0], C_tasks[0,1]

  pbar = trange(nIter)
  for it in pbar:

    params, opt_state_maml = step_outter(params, C_tasks[:,0],  C_tasks[:,1], F_pde_intv, F_ic_intv, F_bcb_intv, F_bct_intv,
                                        U_T_ic_intv_meta_ar, U_a_ic_intv_meta_ar, w_ic_T, w_ic_a, w_bcb, w_bct, w_pde,
                                        opt_state_maml, F_pde_test, F_ic_test, F_bcb_test, F_bct_test, U_T_ic_test_meta_ar, U_a_ic_test_meta_ar)

    # Evaluate progress on one of the support set's task
    if it % 100 == 0 and it > 0:
      loss_value = loss(params, C_bot_0,  C_top_0, F_pde_test, F_ic_test, F_bcb_test, F_bct_test, U_T_ic_test_meta_ar[0],
                        U_a_ic_test_meta_ar[0], w_ic_T, w_ic_a, w_bcb, w_bct, w_pde)
      loss_ics_T_value = loss_ics_T(params, F_ic_test, U_T_ic_test_meta_ar[0])
      loss_ics_a_value = loss_ics_a(params, F_ic_test, U_a_ic_test_meta_ar[0])
      loss_bcb_value = loss_bcb(params, C_bot_0, F_bcb_test)
      loss_bct_value = loss_bct(params, C_top_0, F_bct_test)
      loss_pde_value = loss_pde(params, F_pde_test)
      loss_ode_value = loss_ode(params, F_pde_test)

      # NOTE: `shuffle` is currently ignored. Re-sampling here would mean calling
      # data_generator(time_b, time_b + delta, ...) and rebinding the
      # F_*_intv / F_*_test globals.

      pbar.set_postfix({'Loss': loss_value,
                'loss_ics_T' : loss_ics_T_value,
                'loss_ics_a' : loss_ics_a_value,
                'loss_bcb' : loss_bcb_value,
                'loss_bct' : loss_bct_value,
                'loss_pde':  loss_pde_value,
                'loss_ode': loss_ode_value,
                'w_ic_T' : w_ic_T,
                'w_ic_a' : w_ic_a,
                'w_bcb' : w_bcb,
                'w_bct' : w_bct,})

  return params


In [42]:
# Function to fine-tune meta-learner on each support set task to obtain the initial
# conditions for the next interval
nIter = 5000

def _make_finetune_step(optimizer):
  """Return a jitted single optax step on the PINN ``loss`` for one task.

  Replaces the call to ``step_optax``, which is not defined in any cell shown
  here. The execution counts skip In [43], so it presumably lived in a deleted
  cell.
  """
  @jit
  def finetune_step(params, opt_state, C_bot, C_top, F_pde, F_ic, F_bcb, F_bct, U_T_ic, U_a_ic,
                    w_ic_T, w_ic_a, w_bcb, w_bct, w_pde):
    grads = grad(loss)(params, C_bot, C_top, F_pde, F_ic, F_bcb, F_bct, U_T_ic, U_a_ic,
                       w_ic_T, w_ic_a, w_bcb, w_bct, w_pde)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state

  return finetune_step

def meta_task_training(params, params_meta, C_tasks, nIter):
  """Fine-tune the meta-learner separately on every support-set task.

  The fine-tuned per-task networks provide the initial conditions for the next
  time interval (through meta_inits).

  Reads notebook globals: time_b, T_min, U_min, F_ic_intv / F_*_intv,
  F_*_test and the loss weights w_*.

  Args:
    params: Trained meta-learner parameters (starting point for every task).
    params_meta: Per-task parameters from the previous interval, indexed by
      task. Only read when time_b != 0.0.
    C_tasks: Array whose rows are [C_bot, C_top] per task.
    nIter: Fine-tuning steps per task.

  Returns:
    List of fine-tuned parameter pytrees, one per row of C_tasks.
  """
  # All tasks share the same schedule/optimizer config (with a fresh state per
  # task), so the jitted step is built and compiled once per call.
  scheduler = optax.exponential_decay(init_value=1e-5, transition_steps=5000, decay_rate=0.9)
  optimizer = optax.adam(learning_rate=scheduler)
  finetune_step = _make_finetune_step(optimizer)

  # Results go into a new list so `params_meta` (previous interval) stays
  # readable below when building each task's initial conditions.
  new_params_meta = []
  for i in range(len(C_tasks)):
    params_sub = params
    C_bot, C_top = C_tasks[i,:]
    opt_state = optimizer.init(params_sub)

    # Initial-condition targets on both point sets: the *_test targets drive
    # fine-tuning, and the *_intv targets are used for the progress read-out.
    if time_b == 0.0:
      U_T_ic_test = np.ones([F_ic_test.shape[0], 1])*T_min
      U_a_ic_test = np.ones([F_ic_test.shape[0], 1])*U_min
      U_T_ic_intv = np.ones([F_ic_intv.shape[0], 1])*T_min
      U_a_ic_intv = np.ones([F_ic_intv.shape[0], 1])*U_min
    else:
      prev_params = params_meta[i]
      U_T_ic_test = u_pred_fn(prev_params, F_ic_test[:,0], F_ic_test[:,1])
      U_a_ic_test = u_pred_fn_a(prev_params, F_ic_test[:,0], F_ic_test[:,1])
      U_T_ic_intv = u_pred_fn(prev_params, F_ic_intv[:,0], F_ic_intv[:,1])
      U_a_ic_intv = u_pred_fn_a(prev_params, F_ic_intv[:,0], F_ic_intv[:,1])

    print('Training task: ', str(i))

    # Main training loop
    pbar = trange(nIter)
    for it in pbar:

      # Finetuning
      params_sub, opt_state = finetune_step(params_sub, opt_state, C_bot, C_top, F_pde_test, F_ic_test, F_bcb_test, F_bct_test,
                                            U_T_ic_test, U_a_ic_test, w_ic_T, w_ic_a, w_bcb, w_bct, w_pde)

      if it % 50 == 0 and it > 0:
        loss_value = loss(params_sub, C_bot, C_top, F_pde_test, F_ic_test, F_bcb_test, F_bct_test, U_T_ic_test, U_a_ic_test, w_ic_T, w_ic_a, w_bcb, w_bct, w_pde)
        loss_ics_T_value = loss_ics_T(params_sub, F_ic_intv, U_T_ic_intv)
        loss_ics_a_value = loss_ics_a(params_sub, F_ic_intv, U_a_ic_intv)
        loss_bcb_value = loss_bcb(params_sub, C_bot, F_bcb_intv)
        loss_bct_value = loss_bct(params_sub, C_top, F_bct_intv)
        loss_pde_value = loss_pde(params_sub, F_pde_intv)
        loss_ode_value = loss_ode(params_sub, F_pde_intv)

        pbar.set_postfix({'Loss': loss_value,
                  'loss_ics_T' : loss_ics_T_value,
                  'loss_ics_a' : loss_ics_a_value,
                  'loss_bcb' : loss_bcb_value,
                  'loss_bct' : loss_bct_value,
                  'loss_pde':  loss_pde_value,
                  'loss_ode': loss_ode_value,
                  'HTC_top' : C_top,
                  'HTC_bot' : C_bot})

    new_params_meta.append(params_sub)

  return new_params_meta

In [44]:
# Training requirements
# 1) initialize the neural network
rng = jax.random.PRNGKey(0)
x_init = jax.random.normal(rng, (20, 2))
model = Net(features=2)
params = model.init(rng, x_init)
out = model.apply(params, x_init)

# 2) Define the support set
# Randomly select tasks from the defined task distribution (here, [0.18, 0.58])
n_task = 20
lb_meta = np.array([0.18, 0.18])
ub_meta = np.array([0.58, 0.58])
engine_meta = qmc.LatinHypercube(d=2, seed = 42)
sample_meta = engine_meta.random(n=n_task)
C_tasks = lb_meta + (ub_meta-lb_meta)*sample_meta

# 3) Define weight hyperparameters
w_ic_T = 100 # Temperature initial condition weight
w_ic_a = 100 # DoC initial condition weight
w_bcb = 1 # Bottom boundary condition weight
w_bct = 1 # Top boundary condition weight
w_pde = 1 # PDE/ODE weight

# 4) Define time intervals (sequential learning)
time_b_list = [0.0, 0.1, 0.2, 0.3, 0.4, 0.45, 0.5, 0.55, 0.6, 0.7, 0.8, 0.9] # Intervals' starting point (normalized time domain)
delta_list = [0.1, 0.1, 0.1, 0.1, 0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1] # Intervals' duration

# 5) Training parameters
nIter = 100000
nIter_tasks = 100
n_ic, n_bc, n_f = 200, 200, 2000 # training IC, BC and collocation points count
n_ic_test, n_bc_test, n_f_test = 20, 20, 200 # test IC, BC and collocation points count
params_meta = [] # initialize meta-task parameters for time interval 1

# 6) Log
meta_params_log = []
meta_tasks_params_log = []

In [None]:
# Sequential meta-transfer training. For each time interval:
#   (1) train the meta-learner with MAML on the support-set tasks;
#   (2) fine-tune it on each task so that those per-task networks supply the
#       next interval's initial conditions.
# time_b / delta remain module-level names because meta_task_training reads time_b.
assert len(time_b_list) == len(delta_list), 'time_b_list and delta_list must have the same length'
for i, (time_b, delta) in enumerate(zip(time_b_list, delta_list)):

  print('---------------------------------')
  print('Time interval starts at:', time_b)
  print('Time interval duration:', delta)
  print('---------------------------------')
  print('Training meta-learner', i+1)

  # Specify training/test data (IC, BC and collocation) for meta learner
  F_ic_intv, F_bcb_intv, F_bct_intv, F_pde_intv, F_intv, F_ic_test, F_bcb_test, F_bct_test, F_pde_test, F_test = training_data(time_b, delta, n_ic, n_bc, n_f, n_ic_test, n_bc_test, n_f_test)

  # Initialize the support set tasks initial conditions (required for calculating the IC loss)
  U_T_ic_intv_meta_ar, U_a_ic_intv_meta_ar, U_T_ic_test_meta_ar, U_a_ic_test_meta_ar = meta_inits(n_task, time_b, T_min, U_min, F_ic_intv, F_ic_test, params_meta)

  # Fresh optimizer and state for each interval (step_outter reads optimizer_maml as a global)
  scheduler = optax.exponential_decay(init_value=1e-5, transition_steps=5000, decay_rate=0.9)
  optimizer_maml = optax.adam(learning_rate=scheduler)
  opt_state_maml = optimizer_maml.init(params)

  # IC/BC-only warm-up (PDE/ODE weight 0) is used for the first interval only
  ignore_pde = (i == 0)
  params = meta_train(nIter, params, opt_state_maml, shuffle = False, ignore_pde = ignore_pde)
  meta_params_log.append(params)

  print('Training support set tasks to obtain IC for interval', i+2)

  # Fine-tune meta-learner on each support set task to obtain the initial
  # conditions for the next interval
  params_meta = meta_task_training(params, params_meta, C_tasks, nIter_tasks)
  meta_tasks_params_log.append(params_meta)
