In [1]:
# Interactive matplotlib backend (nbagg).
# NOTE(review): `%matplotlib notebook` is not supported in JupyterLab / Notebook 7;
# `%matplotlib widget` (ipympl) is the replacement there.
%matplotlib notebook
# autoreload mode 1: only modules registered with %aimport (below) are reloaded
# before each cell executes.
%load_ext autoreload
%autoreload 1
# record which host and working directory this notebook session runs on
!hostname
!pwd

dv002.bridges2.psc.edu
/ocean/projects/asc170022p/mtragoza/mre-pinn/notebooks


In [2]:
import sys, os
import numpy as np
import pandas as pd

# Make the mre_pinn package (in the repository root, the parent of notebooks/)
# importable and register it for autoreload.
# NOTE: re-running this cell appends duplicate sys.path entries (harmless).
sys.path.append('..')
%aimport mre_pinn

# param_search lives in a checkout next to the mre-pinn repository; it provides the
# ParamSpace / submit / status / metrics / plot helpers used to run the SLURM sweep.
sys.path.append('../../param_search')
%aimport param_search
ps = param_search

Using backend: pytorch



# Training PINNs conditioned on MRI anatomical images

The objective is to test whether conditioning PINNs on MRI anatomical images improves their elasticity reconstruction performance relative to a baseline method. We use the BIOQIC phantom dataset and evaluate each of its data frequencies separately. Every configuration is trained both with and without conditioning (`conditional` = 1 / 0), so the effect of conditioning can be compared directly.

In [3]:
# define the job template and name format
# The SLURM batch script runs train.py for one parameter configuration; each
# {placeholder} is presumably filled in by ps.submit (via str.format) from that
# configuration plus the generated job_name.

template = '''\
#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --account=asc170022p
#SBATCH --partition=GPU-shared
#SBATCH --gres=gpu:1
#SBATCH --time=48:00:00
#SBATCH -o %J.stdout
#SBATCH -e %J.stderr
#SBATCH --mail-type=all

hostname
pwd
source activate MRE-PINN

python ../../../train.py \\
    --data_root ../../../data/BIOQIC \\
    --data_name phantom \\
    --frequency {frequency} \\
    --xyz_slice {xyz_slice} \\
    --noise_ratio 0.0 \\
    --pde_name {pde_name} \\
    --omega0 {omega0} \\
    --n_layers {n_layers} \\
    --n_hidden {n_hidden} \\
    --activ_fn {activ_fn} \\
    --polar {polar} \\
    --conditional {conditional} \\
    --optimizer adam \\
    --learning_rate {learning_rate} \\
    --pde_loss_wt {pde_loss_wt} \\
    --data_loss_wt {data_loss_wt} \\
    --batch_size {batch_size} \\
    --n_iters {n_iters} \\
    --test_every {test_every} \\
    --save_every {save_every} \\
    --save_prefix {job_name}
'''

# The job name must contain every parameter that varies across the sweep
# (frequency, pde_name, omega0, polar, conditional) so that no two configurations
# share a name. The job name is also used as --save_prefix, so a collision would
# make jobs overwrite each other's outputs and merge them in groupby('job_name').
name = 'train_{frequency}_{xyz_slice}_{pde_name}_{omega0}_{polar}_{conditional}'

# define the parameter space
# List-valued entries are the swept hyperparameters; scalar entries are held fixed
# for every job (presumably ps.ParamSpace takes the product of the lists -- see its docs).

swept_params = dict(
    frequency=[50, 60, 70, 80, 90, 100],
    xyz_slice=['2D'],
    pde_name=['helmholtz', 'hetero'],
    omega0=[1, 4, 16],
    n_layers=[5],
    n_hidden=[128],
    activ_fn=['s'],
    polar=[0, 1],
    conditional=[0, 1],
)
fixed_params = dict(
    learning_rate=1e-4,
    pde_loss_wt=1e-8,
    data_loss_wt=1,
    batch_size=128,
    n_iters=250000,
    test_every=1000,
    save_every=10000,
)
param_space = ps.ParamSpace(**swept_params, **fixed_params)

len(param_space)

72

In [4]:
%autoreload
expt_name = '2022-10-03_conditional'

jobs = ps.submit(template, name, list(param_space)[:1], work_dir=expt_name, verbose=True)
jobs.to_csv(f'{expt_name}.jobs')

#import pandas as pd
#jobs = pd.read_csv(f'{expt_name}.jobs')

100%|██████████| 1/1 [00:00<00:00, 89.91it/s]
[11767592]


In [None]:
%autoreload
status_cols = ['job_name', 'job_state', 'node_id', 'runtime', 'stdout', 'stderr']
ps.status(jobs)[status_cols] #.iloc[0].stderr

In [None]:
# print (rather than display) the first job's stdout so its line breaks are kept
first_job_status = ps.status(jobs)[status_cols].iloc[0]
print(first_job_status['stdout'])

In [None]:
# Load the evaluation metrics logged by every job.
# NOTE(review): the 'mean_abs_value' column is relabelled 'median_abs_value' (abbreviated
# MAV / MAD below) -- confirm which statistic ps.metrics / mre_pinn actually reports.
raw_metrics = ps.metrics(jobs)
metrics = raw_metrics.rename(columns={'mean_abs_value': 'median_abs_value'})
metrics

In [None]:
# did every model train for the full n_iters (250k) iterations?
EXPECTED_ITERS = 250_000
final_iters = metrics.groupby('job_name')['iteration'].max()
# guard against an empty metrics table, which would otherwise pass vacuously
assert len(final_iters) > 0, 'no metrics found for any job'
assert (final_iters == EXPECTED_ITERS).all(), \
    f'jobs that did not finish:\n{final_iters[final_iters != EXPECTED_ITERS]}'

# experimental parameters: the hyperparameters swept in param_space
# (assumes ps.metrics carries each job's parameters as columns -- TODO confirm)
param_cols = ['pde_name', 'frequency', 'omega0', 'polar', 'conditional']
index_cols = ['iteration', 'variable_name', 'spatial_frequency_bin', 'spatial_region'] # metric identifiers
metric_cols = ['mean_squared_abs_value', 'power_density', 'median_abs_value'] # metric values

# average each metric within a (job, parameters, metric identifier) group, then move
# variable_name into the columns -> (metric_name, variable_name) column tuples
group_cols = ['job_name'] + param_cols + index_cols
m = metrics.groupby(group_cols, sort=False)[metric_cols].mean() \
    .unstack('variable_name')

def abbreviate_metrics(t):
    '''Flatten a (metric_name, variable_name) column tuple into a short column name.

    The metric is abbreviated (MSAV, MAV or PSD) and appended to the variable name.
    Metrics of the `*_diff` variables are then renamed as prediction errors
    (`*_pred_MSAE`, `*_pred_MAD`), and the MSAV of `f_sum` becomes `PDE_MSAE`.

    Raises:
        KeyError: if the metric name is not one of the three known metrics.
    '''
    metric_abbrevs = {
        'mean_squared_abs_value': 'MSAV',
        'median_abs_value': 'MAV',
        'power_density': 'PSD',
    }
    # substring renames, applied in this order
    error_renames = (
        ('diff_MSAV', 'pred_MSAE'),
        ('f_sum_MSAV', 'PDE_MSAE'),
        ('diff_MAV', 'pred_MAD'),
    )
    metric_name, var_name = t
    col_name = f'{var_name}_{metric_abbrevs[metric_name]}'
    for old, new in error_renames:
        col_name = col_name.replace(old, new)
    return col_name

# flatten the (metric, variable) column tuples into abbreviated names and
# turn the group keys back into ordinary columns
flat_cols = [abbreviate_metrics(col) for col in m.columns.to_flat_index()]
m = m.set_axis(flat_cols, axis=1).reset_index()
m

In [None]:
# distinct spatial-region labels (the plots below exclude '-1' and 'all')
m['spatial_region'].unique()

In [None]:
# plot the wave field error

m['u_pred_MSAE_rel'] = m['u_pred_MSAE'] / m['u_true_MSAV']
m['u_pred_MAD_rel'] = m['u_pred_MAD'] / m['u_true_MAV']

fig = ps.plot(
    m[(m.iteration > 200e3) & ~m.spatial_region.isin({'-1', 'all'})].copy(),
    x=param_cols,
    y=['u_pred_MAD', 'u_pred_MAD_rel'],
    grouped=True,
    height=4,
    width=3,
    legend=True,
    tight=True
)
fig.suptitle('Wave field error', x=0.5, y=0.98)
fig.tight_layout()

In [None]:
# plot the elastogram error

m['mu_pred_MSAE_rel'] = m['mu_pred_MSAE'] / m['mu_true_MSAV']
m['mu_pred_MAD_rel'] = m['mu_pred_MAD'] / m['mu_true_MAV']

fig = ps.plot(
    m[(m.iteration > 200e3) & ~m.spatial_region.isin({'-1', 'all'})].copy(),
    x=param_cols,
    y=['mu_pred_MAD', 'mu_pred_MAD_rel'],
    grouped=True,
    height=4,
    width=3,
    legend=True,
    tight=True
)
fig.suptitle('Elasticity error', x=0.5, y=0.98)
fig.tight_layout()

In [None]:
# plot the Laplacian error (model Laplacian vs finite differences) to assess overfitting

m['lu_pred_MSAE_rel'] = m['lu_pred_MSAE'] / m['lu_pred_MSAV']
m['lu_pred_MAD_rel'] = m['lu_pred_MAD'] / m['lu_pred_MAV']

fig = ps.plot(
    m[(m.iteration > 200e3) & ~m.spatial_region.isin({'-1', 'all'})].copy(),
    x=param_cols,
    y=['lu_pred_MAV', 'lu_pred_MAD_rel'],
    grouped=True,
    height=4,
    width=3,
    legend=True,
    tight=True
)
fig.suptitle('Laplacian deviation', x=0.5, y=0.98)
fig.tight_layout()