# Learning of a Physics-informed DeepONet for quantum graphs

In this notebook, we train a physics-informed *DeepONet* for the advection-diffusion equation considered in our paper.
The code uses the GitHub repository accompanying the paper on physics-informed DeepONets as a template.

## Import necessary packages

In [None]:
#import numpy as np
import matplotlib.pyplot as plt
%matplotlib widget
import re
import numpy as np
import jax.numpy as jnp
from jax import random, grad, vmap, jit, config
import jax
from tqdm import trange

from functools import partial
from jax.example_libraries import optimizers
from jax.flatten_util import ravel_pytree
import itertools
from jax import device_put
from jax.example_libraries import optimizers
import pickle

import os
os.environ['CUDA_VISIBLE_DEVICES']='0'
import sys

# Import one example to get PDE (it's the same for all examples)
from src.graph import Example0


## Set important parameters and model to be trained

In [None]:
# PDE parameter eps handed to Example0 (presumably the diffusion coefficient —
# TODO confirm); it is also part of the data and output paths.
EPS = 0.1

# Directory holding the pre-computed training/validation data for this EPS.
DATA_PATH = f'./PI_DeepONet_Data/variable_velocity_eps{EPS}_n10_nx200_nt1000/'

# Candidate model configurations as (N_WIDTH, N_SPLIT, mode):
#   N_WIDTH ... width of the hidden layers
#   N_SPLIT ... number of batches the data is split into (adapt to the available hardware)
#   mode    ... number of training samples used:
#               0 - 5K samples, 1 - 10K samples, 2 - 20K samples, 4 - largest split,
#               -1 - tiny split for testing
MODEL_CONFIGS = [
    (100, 1, -1),
    (100, 2, 1),
    (100, 4, 2),
    (200, 2, 0),
    (200, 4, 1),
    (200, 8, 2),
    (100, 8, 0),
]
N_WIDTH, N_SPLIT, mode = MODEL_CONFIGS[0]

N_EPOCHS = 20000          # number of training epochs
N_RES_PER_SAMPLE = 101    # residual points per sample used to evaluate the PINN loss

# Edge type to train: 'inflow', 'inner' or 'outflow'.
# A complete set of models requires running this notebook once for each edge type.
train_model = 'inner'

# Common prefix for every artifact written below (loss plot, loss CSVs, parameter pickles).
PARAM_PATH = (f'./final_params/params_eps_{EPS}_{train_model}_mode_{mode}'
              f'_width_{N_WIDTH}_split_{N_SPLIT}_epochs_{N_EPOCHS}')

### Identify training and validation data according to the `mode` variable

In [None]:
# Train/validation sample indices for examples 2, 4 and 6 (the same indices are
# used for all three examples).
# Each mode maps to (number of training samples, (start, stop) of the
# validation index range); the training indices are 0 .. n_train-1.
SPLITS_BY_MODE = {
    -1: (1, (1, 2)),
    0: (5, (5, 6)),
    1: (10, (10, 12)),
    2: (20, (20, 24)),
    4: (40, (95, 100)),
}

# Fail here instead of with a NameError in the data-loading cell.
if mode not in SPLITS_BY_MODE:
    raise ValueError(f'Unsupported mode {mode!r}; expected one of {sorted(SPLITS_BY_MODE)}')

n_train, (val_start, val_stop) = SPLITS_BY_MODE[mode]

# Separate arrays per example, so mutating one cannot affect the others.
train_idx_2 = np.arange(n_train, dtype=np.int16)
train_idx_4 = np.arange(n_train, dtype=np.int16)
train_idx_6 = np.arange(n_train, dtype=np.int16)

val_idx_2 = np.arange(val_start, val_stop, dtype=np.int16)
val_idx_4 = np.arange(val_start, val_stop, dtype=np.int16)
val_idx_6 = np.arange(val_start, val_stop, dtype=np.int16)

## Load data

In [None]:
from src.dataHandling import loadData

# Training and validation sets come from the same directory; only the
# per-example sample indices (examples 2, 4 and 6) differ.
train_indices = (train_idx_2, train_idx_4, train_idx_6)
val_indices = (val_idx_2, val_idx_4, val_idx_6)

DATA = loadData(DATA_PATH, *train_indices)
FULL_VAL_DATA = loadData(DATA_PATH, *val_indices)


Print the shapes of the loaded training data.

In [None]:
# DATA is consumed in triples (entries 0:3, 3:6 and 6:, one triple per edge type);
# blank lines separate the triples in the printout.
for group_idx, group in enumerate(DATA):
    if group_idx in (3, 6):
        print('\n')
    for arr_idx, arr in enumerate(group):
        print(f'Shape of DATA[{group_idx}][{arr_idx}]: {arr.shape}')

In [None]:
from src.networks_velocity import MLP, FF_MLP, modified_MLP, PI_DeepONet
from src.networks_velocity import n_res_data, n_init_data, n_bc_data

import optax

# Keep JAX in 32-bit floating point mode.
config.update("jax_enable_x64", False)

# Branch input size: 101 inflow + 101 outflow + 101 initial-condition sensor
# positions, plus 1 velocity component.
m = 304

# Passed to model.train as N_DATA_BC.
N_DATA_BC = 101

# Both networks: input size followed by 7 layers of width N_WIDTH.
N_STACKED_LAYERS = 7
branch_layers = [m] + [N_WIDTH] * N_STACKED_LAYERS
trunk_layers = [2] + [N_WIDTH] * N_STACKED_LAYERS

# The graph is only used to obtain the drift-diffusion PDE (identical for all examples).
graph = Example0(eps=EPS)

# Default loss weights; the edge-specific boundary-physics weight is adjusted
# later depending on train_model.
loss_weights = {
    'res': 0.,
    'bcs': 0.,
    'ics': 1.,
    'physics': 1.,
    'physics_bcs_single_edge': 0.,
    'physics_bcs_inflow': 0.,
    'physics_bcs_inner': 0.,
    'physics_bcs_outflow': 0.,
}

# Learning rate starts at 1e-3 and decays by a factor 0.9 per 2000 steps.
scheduler = optax.schedules.exponential_decay(init_value=1e-3,
                                              transition_steps=2000,
                                              decay_rate=0.9)

# Adam with global-norm gradient clipping and the schedule above. optax updates
# are added to the parameters, so the final scale(-1) turns them into descent steps.
solver = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.scale_by_adam(),
    optax.scale_by_schedule(scheduler),
    optax.scale(-1.0),
)
solver_is_lbfgs = False

model = PI_DeepONet(
    graph, branch_layers, trunk_layers,
    branch_net=FF_MLP, trunk_net=FF_MLP,
    solver=solver, solver_is_lbfgs=solver_is_lbfgs,
)


In [None]:
# DATA and FULL_VAL_DATA hold one (RES, INIT, BC) triple per edge type:
# inflow -> entries 0:3, inner -> 3:6, outflow -> 6:.
# Per edge type: (slice into the data, loss-weight key of its boundary-physics
# term, name of an optional warm-start parameter set).
EDGE_SETTINGS = {
    'inflow': (slice(0, 3), 'physics_bcs_inflow', 'params_inflow'),
    'inner': (slice(3, 6), 'physics_bcs_inner', 'params_inner'),
    'outflow': (slice(6, None), 'physics_bcs_outflow', 'params_outflow'),
}
if train_model not in EDGE_SETTINGS:
    raise ValueError(f"Unknown train_model {train_model!r}; "
                     f"expected one of {list(EDGE_SETTINGS)}")
data_slice, active_bcs_weight, warm_start_name = EDGE_SETTINGS[train_model]

# Optional warm start: if a variable named params_<edge type> exists in the
# notebook namespace, continue training from it with a fresh optimizer state.
if warm_start_name in globals():
    warm_start_params = globals()[warm_start_name]
    model.params = warm_start_params
    model.opt_state = model.solver.init(warm_start_params)

# Enable only the boundary-physics term of the edge type being trained. The
# other edge types are reset to 0 so that re-running this cell with a different
# train_model does not leave a stale weight in the shared loss_weights dict.
for edge_settings in EDGE_SETTINGS.values():
    loss_weights[edge_settings[1]] = 0.
loss_weights[active_bcs_weight] = 1

RES_DATA, INIT_DATA, BC_DATA = DATA[data_slice]
VAL_DATA = FULL_VAL_DATA[data_slice]

# Keep a random subset of N_RES_PER_SAMPLE residual points (axis 1 of entries 1
# and 2 of the residual data), drawn with a fixed key for reproducibility.
# NOTE(review): the same indices are applied to the validation residual data,
# which assumes it has at least as many points per sample as RES_DATA — confirm.
if N_RES_PER_SAMPLE < RES_DATA[1].shape[1]:
    key = random.PRNGKey(22)
    shuffle_idx = random.permutation(key, RES_DATA[1].shape[1])[:N_RES_PER_SAMPLE]
    RES_DATA[1] = RES_DATA[1][:, shuffle_idx, :]
    RES_DATA[2] = RES_DATA[2][:, shuffle_idx, :]

    VAL_DATA[0][1] = VAL_DATA[0][1][:, shuffle_idx, :]
    VAL_DATA[0][2] = VAL_DATA[0][2][:, shuffle_idx, :]

print(f'\nData for {train_model} model:')
print(f'Number of res data points: {n_res_data(RES_DATA)}')
print(f'Number of init data points: {n_init_data(INIT_DATA)}')
print(f'Number of bc data points: {n_bc_data(BC_DATA)}')


## Train model

In [None]:
# All data and loss settings were prepared in the cells above; the PRNG key
# passed to model.train uses the fixed seed 3.
train_kwargs = dict(
    nEpochs=N_EPOCHS,
    weights=loss_weights,
    RES_DATA=RES_DATA,
    INIT_DATA=INIT_DATA,
    BC_DATA=BC_DATA,
    N_SPLIT=N_SPLIT,
    N_DATA_BC=N_DATA_BC,
    VAL_DATA=VAL_DATA,
)
training_key = random.PRNGKey(3)
model.train(training_key, **train_kwargs)

## Generate loss plot and store loss history

In [None]:
%matplotlib widget
plt.figure(figsize=(9,5), clear=True)
step = 1
L = np.arange(len(model.val_loss_res_log))[::step]

plt.semilogy(L, model.val_loss_res_log[::step],
             label='res_log', alpha=0.7)
plt.semilogy(L, model.val_loss_physics_log[::step],
             label='pde_ph_log', alpha=0.7)
plt.semilogy(L, model.val_loss_ics_log[::step],
             label='ics_log', alpha=0.7)
plt.semilogy(L, model.val_loss_bcs_log[::step],
             label='bcs_log', alpha=0.7)
plt.semilogy(L, model.val_loss_bnd_physics_log[::step],
             label='bnd_ph_log', alpha=0.7)
plt.gca().set_prop_cycle(None)

L = np.arange(len(model.train_loss_res_log))[::step]
plt.semilogy(L, model.train_loss_res_log[::step],
             linestyle='dashed',
             label='res_log', alpha=0.7)
plt.semilogy(L, model.train_loss_physics_log[::step],
             linestyle='dashed',
             label='pde_ph_log', alpha=0.7)
plt.semilogy(L, model.train_loss_ics_log[::step],
             linestyle='dashed',
             label='ics_log', alpha=0.7)
plt.semilogy(L, model.train_loss_bcs_log[::step],
             linestyle='dashed',
             label='bcs_log', alpha=0.7)
plt.semilogy(L, model.train_loss_bnd_physics_log[::step],
             linestyle='dashed',
             label='bnd_ph_log', alpha=0.7)

plt.title('Loss history')
ax = plt.gca()
ax.set_axisbelow(True)
ax.yaxis.grid(color='gray', linestyle='solid')
plt.legend()
plt.show()

LOSS_FIG_PATH = f'{PARAM_PATH}_loss_{model.val_loss_log[-1]:5.2e}.png'
plt.savefig(LOSS_FIG_PATH )

In [None]:
# Each file (written by np.savetxt, space separated) holds one *row* per loss
# component, in the order named in the header line, and one column per logged step.
LOSS_CSV_HEADER = 'loss res ph ics bcs bcs_ph'
LOSS_LOG_SUFFIXES = ('', '_res', '_physics', '_ics', '_bcs', '_bnd_physics')
final_val_loss = f'{model.val_loss_log[-1]:5.2e}'


def _collect_loss_logs(split):
    """Return model's loss histories for `split` ('val' or 'train') in header order."""
    return [getattr(model, f'{split}_loss{suffix}_log') for suffix in LOSS_LOG_SUFFIXES]


val_loss_log = _collect_loss_logs('val')
VAL_LOSS_CSV_PATH = f'{PARAM_PATH}_loss_{final_val_loss}_VAL_LOSS.csv'
np.savetxt(VAL_LOSS_CSV_PATH,
           np.stack([np.array(log) for log in val_loss_log]),
           header=LOSS_CSV_HEADER)

train_loss_log = _collect_loss_logs('train')
TRAIN_LOSS_CSV_PATH = f'{PARAM_PATH}_loss_{final_val_loss}_TRAIN_LOSS.csv'
np.savetxt(TRAIN_LOSS_CSV_PATH,
           np.stack([np.array(log) for log in train_loss_log]),
           header=LOSS_CSV_HEADER)

## Save final model

In [None]:
import os
import pickle
from datetime import datetime


def _dump_pickle(obj, filename):
    """Pickle `obj` to `filename`, closing the file handle even if pickling fails."""
    with open(filename, 'wb') as fh:
        pickle.dump(obj, fh)


# File names are unchanged from earlier runs (including the doubled "_FF"
# suffix), so existing loaders and glob patterns keep matching.
str_loss = f'{model.val_loss_log[-1]:5.2e}_FF'
str_datetime = datetime.now().strftime("%m-%d-%Y-%H%M%S")
PARAM_FILENAME = f'{PARAM_PATH}_loss_{str_loss}_{str_datetime}_FF.pkl'

# PARAM_PATH points into ./final_params/, which may not exist yet.
os.makedirs(os.path.dirname(PARAM_FILENAME) or '.', exist_ok=True)

# Final parameters after the last epoch.
_dump_pickle(model.params, PARAM_FILENAME)

# Parameters of the best model tracked during training (same timestamp).
str_loss_best = f'{model.best_model_loss:5.2e}_FF_best'
PARAM_FILENAME_BEST = f'{PARAM_PATH}_loss_{str_loss_best}_{str_datetime}_FF.pkl'
_dump_pickle(model.best_model_params, PARAM_FILENAME_BEST)