In [1]:
%matplotlib notebook
%load_ext autoreload
%autoreload 1
!hostname
!pwd

v025.ib.bridges2.psc.edu
/ocean/projects/asc170022p/mtragoza/mre-pinn/MICCAI-2023


In [2]:
import sys, os
import numpy as np
import pandas as pd

sys.path.append('..')
%aimport mre_pinn

sys.path.append('../../param_search')
%aimport param_search
ps = param_search

Using backend: pytorch



# MICCAI 2023 simulation experiment

In [3]:
%pwd

'/ocean/projects/asc170022p/mtragoza/mre-pinn/MICCAI-2023'

In [4]:
# SLURM batch-script template for a single training job. Its `{...}` placeholders
# match keys of `param_space` below, plus `{job_name}`.
# NOTE(review): `savgol_filter` is in `param_space` and in `name`, but the template
# passes no corresponding argument to train.py — confirm that is intended.
template = '''\
#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --account=asc170022p
#SBATCH --partition=GPU-shared
#SBATCH --gres=gpu:1
#SBATCH --time=48:00:00
#SBATCH -o %J.stdout
#SBATCH -e %J.stderr
#SBATCH --mail-type=all

hostname
pwd
module load anaconda3
conda activate /ocean/projects/asc170022p/mtragoza/mambaforge/envs/MRE-PINN

python ../../../train.py \\
    --xarray_dir {xarray_dir} \\
    --example_id {example_id} \\
    --frequency {frequency} \\
    --noise_ratio {noise_ratio} \\
    --omega {omega} \\
    --n_layers {n_layers} \\
    --activ_fn {activ_fn} \\
    --polar_input {polar_input} \\
    --pde_name {pde_name} \\
    --pde_warmup_iters 10000 \\
    --pde_step_iters 5000 \\
    --pde_step_factor 10 \\
    --pde_init_weight {pde_init_weight} \\
    --pde_loss_weight {pde_loss_weight} \\
    --save_prefix {job_name}    
'''
# Job-name pattern, formatted with the values of one parameter combination.
name = 'train_{data_name}_{example_id}_{noise_ratio:.0e}_{savgol_filter}_{omega}_{pde_name}'

# Parameter space of the experiment. The 144 combinations reported by len()
# equal 6 x 6 x 2 x 2, the product of the multi-valued lists below.
param_space = ps.ParamSpace(
    data_name='fem_box',
    xarray_dir='../../../data/BIOQIC/fem_box',
    example_id=[50, 60, 70, 80, 90, 100],
    frequency='auto',
    noise_ratio=[0, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1],
    omega=[60],
    n_layers=[5],
    activ_fn=['ss'],
    savgol_filter=[0, 1],
    pde_name=['helmholtz', 'hetero'],
    pde_init_weight=1e-10,
    pde_loss_weight=1e-8,
    polar_input=0,
)

len(param_space)

144

In [5]:
name.format(**list(param_space)[0])

'train_fem_box_50_0e+00_0_60_helmholtz'

In [6]:
%autoreload

# Experiment name: used as the job working directory and as the stem of the .jobs file.
expt_name = '2023-2-24_sim_noise2'

# Submission is left commented out; the job table saved by an earlier submission
# is read back from disk instead.
#jobs = ps.submit(template, name, list(param_space), work_dir=expt_name, verbose=True)
#jobs.to_csv(f'{expt_name}.jobs')

jobs = pd.read_csv(f'{expt_name}.jobs')

In [7]:
%autoreload
# Columns displayed in the status overview.
status_cols = ['job_name', 'job_state', 'node_id', 'runtime', 'stdout', 'stderr']
# Query the state of every job; parse_stdout/parse_stderr presumably attach each
# job's log contents as the `stdout`/`stderr` columns — confirm in param_search.
status = ps.status(jobs, parse_stdout=True, parse_stderr=True)
status[status_cols]

Unnamed: 0_level_0,job_name,job_state,node_id,runtime,stdout,stderr
job_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
14688853,train_fem_box_50_0e+00_0_60_helmholtz,,,,v031,
14688854,train_fem_box_50_0e+00_0_60_hetero,,,,v030,
14688855,train_fem_box_50_0e+00_1_60_helmholtz,,,,v022,
14688856,train_fem_box_50_0e+00_1_60_hetero,,,,v020,
14688857,train_fem_box_50_1e-05_0_60_helmholtz,,,,v026,
...,...,...,...,...,...,...
14688992,train_fem_box_100_1e-02_1_60_hetero,,,,v032,
14688993,train_fem_box_100_1e-01_0_60_helmholtz,,,,v003,
14688994,train_fem_box_100_1e-01_0_60_hetero,,,,v005,
14688995,train_fem_box_100_1e-01_1_60_helmholtz,,,,v014,


In [8]:
def get_error_type(e):
    """Collapse a stderr text to a known error label.

    Returns the first known marker (checked in the order listed) that occurs
    as a substring of `e`; when none occurs, `e` is returned unchanged.
    """
    known_markers = (
        'CANCELLED',
        'python: command not found',
        'Unexpected error from cudaGetDeviceCount()',
    )
    return next((marker for marker in known_markers if marker in e), e)

# A stderr entry counts as "file present" unless it is a float NaN.
status['has_stderr_file'] = status.stderr.map(
    lambda entry: not isinstance(entry, float) or not np.isnan(entry)
)
# Non-null and non-empty stderr text.
status['has_stderr'] = ~(status.stderr.isnull() | (status.stderr == ''))
# Reduce each stderr text to a known error label where one matches.
status['error'] = status.stderr.astype(str).map(get_error_type)

# Missing values are shown as 'DONE' before counting jobs per combination.
status.fillna('DONE').groupby(
    ['job_state', 'has_stderr_file', 'has_stderr', 'error']
)[['job_name']].count()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,job_name
job_state,has_stderr_file,has_stderr,error,Unnamed: 4_level_1
DONE,True,False,,144


In [9]:
status[status.has_stderr].groupby(['error', 'stdout'])[['job_name']].count()

Unnamed: 0_level_0,Unnamed: 1_level_0,job_name
error,stdout,Unnamed: 2_level_1


In [10]:
# Gather the metrics table for all jobs.
metrics = ps.metrics(jobs)

# did all models train to completion?
assert (metrics.groupby('job_name')['iteration'].max() == 100e3).all()

# get the final test evaluations
metrics = metrics[metrics.iteration == 100e3]

param_cols = ['pde_name', 'example_id', 'noise_ratio', 'savgol_filter'] # experimental parameters
index_cols = ['variable_name', 'spatial_frequency_bin', 'spatial_region'] # metric group columns
metric_cols = ['MSAV', 'PSD', 'MAV', 'R'] # metric value columns

# Average every metric within each (parameters, metric group) combination;
# sort=False keeps the groups in first-seen order.
metrics = metrics.groupby(param_cols + index_cols, sort=False)[metric_cols].mean()
# Index level len(param_cols) is 'variable_name' (the first entry of index_cols);
# moving it into the columns yields (metric, variable) column pairs.
metrics = metrics.unstack(level=[len(param_cols)])

def metric_map(t):
    """Build a flat column name from a (metric_name, variable_name) tuple.

    The base name is '<variable>_<metric>'. Difference variables are then
    renamed as error metrics:
      - 'pde_diff_MSAV' -> 'PDE_MSAE'
      - 'diff_MSAV'     -> 'pred_MSAE'
      - 'diff_MAV'      -> 'pred_MAD'

    NOTE: the PDE rule used to run after the generic 'diff_MSAV' rule and
    therefore never matched ('pde_diff_MSAV' came out as 'pde_pred_MSAE').
    Any consumer of that old column name must switch to 'PDE_MSAE'.
    """
    metric_name, var_name = t
    new_col_name = f'{var_name}_{metric_name}'
    # Most specific pattern first: 'pde_diff_MSAV' contains 'diff_MSAV', so the
    # generic rule would otherwise rewrite it before the PDE rule could match.
    new_col_name = new_col_name.replace('pde_diff_MSAV', 'PDE_MSAE')
    new_col_name = new_col_name.replace('diff_MSAV', 'pred_MSAE')
    new_col_name = new_col_name.replace('diff_MAV', 'pred_MAD')
    return new_col_name

# Flatten the two-level (metric, variable) columns into single string names.
metrics.columns = [metric_map(t) for t in metrics.columns.to_flat_index()]
metrics

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,a_pred_MSAV,a_pred_MSAE,a_true_MSAV,u_pred_MSAV,u_pred_MSAE,u_true_MSAV,lu_pred_MSAV,lu_pred_MSAE,Lu_true_MSAV,pde_grad_MSAV,...,Lu_true_R,pde_grad_R,pde_diff_R,mu_diff_R,mu_pred_R,mu_true_R,direct_pred_R,direct_diff_R,fem_pred_R,fem_diff_R
pde_name,example_id,noise_ratio,savgol_filter,spatial_frequency_bin,spatial_region,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1
helmholtz,50,0.0,0,all,all,0.0,0.0,0.0,0.000107,6.762480e-09,0.000107,118515.014457,6686.630434,112126.908859,0.0,...,,,,,0.497088,,0.6448,,0.381062,
helmholtz,50,0.0,0,all,1,,,,,,,,,,,...,,,,,,,,,,
helmholtz,50,0.0,0,all,2,,,,,,,,,,,...,,,,,,,,,,
helmholtz,50,0.0,0,all,3,,,,,,,,,,,...,,,,,,,,,,
helmholtz,50,0.0,0,all,4,,,,,,,,,,,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
hetero,100,0.1,1,6.0,all,,,,,,,,,,,...,,,,,,,,,,
hetero,100,0.1,1,7.0,all,,,,,,,,,,,...,,,,,,,,,,
hetero,100,0.1,1,8.0,all,,,,,,,,,,,...,,,,,,,,,,
hetero,100,0.1,1,9.0,all,,,,,,,,,,,...,,,,,,,,,,


In [11]:
# Turn the index levels (parameters and metric groups) back into ordinary columns.
m = metrics.reset_index()

# Overview of the whole-image metrics (rows where both region and frequency bin
# are 'all') against every experimental parameter.
fig = ps.plot(
    m[(m.spatial_region == 'all') & (m.spatial_frequency_bin == 'all')].copy(),
    x=param_cols,
    y=['mu_pred_MSAE', 'mu_pred_R', 'direct_pred_R', 'fem_pred_R'],
    height=2.5,
    width=2.5,
    legend=False,
    tight=True
)

<IPython.core.display.Javascript object>

In [12]:
# Label the rows of `m` as PINN results; __getitem__ makes an unknown pde_name raise KeyError.
m['pde_abbrev'] = m.pde_name.map({'helmholtz': 'HH', 'hetero': 'het'}.__getitem__)
m['method'] = 'PINN-' + m['pde_abbrev']

# AHI rows: taken from the helmholtz runs only, with every `direct*` column
# copied over the matching `mu*` column so all methods share the same column names.
ahi_m = m[m.pde_name == 'helmholtz'].copy()
ahi_m['method'] = 'AHI'
for col in m.columns:
    if col.startswith('direct'):
        ahi_m[col.replace('direct', 'mu')] = m[col]

# FEM rows: one per PDE, with every `fem*` column copied over the matching `mu*` column.
fem_m = m.copy()
fem_m['method'] = 'FEM-' + m['pde_abbrev']
for col in m.columns:
    if col.startswith('fem'):
        fem_m[col.replace('fem', 'mu')] = m[col]

# Long table with one block of rows per method.
mm = pd.concat([ahi_m, fem_m, m])

# Explicit display order for methods and regions; list.index raises ValueError
# for any value that is not listed.
method_order = ['AHI', 'FEM-HH', 'FEM-het', 'PINN-HH', 'PINN-het']
region_order = ['all', '1', '2', '3', '4', '5']
mm['method_index'] = mm['method'].map(method_order.index)
mm['region_index'] = mm['spatial_region'].map(region_order.index)
mm = mm.sort_values(['method_index', 'region_index'])

In [13]:
import matplotlib as mpl
# Shared y-axis formatter that appends a percent sign to tick values.
pct_format = mpl.ticker.PercentFormatter()

# Error metrics as a percentage of the corresponding ground-truth magnitude metric.
mm['mu_pred_MSAE_relative'] = mm['mu_pred_MSAE'] / mm['mu_true_MSAV'] * 100
mm['mu_pred_MAD_relative'] = mm['mu_pred_MAD'] / mm['mu_true_MAV'] * 100

In [15]:
%autoreload
import seaborn as sns
import matplotlib.pyplot as plt

# Ten-color palette from param_search, unpacked into named colors.
colors = ps.results.get_color_palette(10, type='bright', min_val=0)
blue, orange, green, red, purple, brown, pink, gray, yellow, cyan = colors

# Five colors, one per entry of method_order, installed as the seaborn default palette.
colors = [red, yellow, green, cyan, blue]
sns.set_palette(colors)
colors = sns.color_palette()
colors

In [16]:
import matplotlib as mpl

# Relative MAD of mu per method, for noise-free runs without Savitzky-Golay filtering.
# `savgol_filter` is a 0/1 integer flag, so it is compared explicitly instead of
# negated with `~` (bitwise NOT on integers only works through low-bit masking).
fig = ps.plot(
    mm[
        (mm.spatial_frequency_bin == 'all') &
        (mm.noise_ratio == 0.0) &
        (mm.savgol_filter == 0)
    ].copy(),
    x='method',
    y='mu_pred_MAD_relative',
    height=3,
    width=7,
    tight=True,
    legend=False,
    plot_func=ps.results.barplot
)
# Raw string: '\m' is not a valid Python escape sequence.
fig.axes[0].set_ylabel(r'$\mu$ relative MAD (%)')
fig.axes[0].yaxis.set_major_formatter(pct_format)
fig.tight_layout()

# Recolor the bars with the method palette, cycling if there are more bars than colors.
for i, patch in enumerate(fig.axes[0].patches):
    plt.setp(patch, facecolor=colors[i%len(colors)])

for ext in ['png', 'pdf']:
    fig.savefig(f'images/fem_box_method_bar_plot.{ext}', bbox_inches='tight', dpi=200)

<IPython.core.display.Javascript object>



In [17]:
# Per-method summary of relative MAD over individual regions (noise-free, unfiltered).
# The 0/1 integer `savgol_filter` flag is compared explicitly rather than negated with `~`.
mm[
    (mm.spatial_region != 'all') &
    (mm.noise_ratio == 0.0) &
    (mm.savgol_filter == 0)
].groupby(['method'])[['mu_pred_MAD_relative']].describe()

Unnamed: 0_level_0,mu_pred_MAD_relative,mu_pred_MAD_relative,mu_pred_MAD_relative,mu_pred_MAD_relative,mu_pred_MAD_relative,mu_pred_MAD_relative,mu_pred_MAD_relative,mu_pred_MAD_relative
Unnamed: 0_level_1,count,mean,std,min,25%,50%,75%,max
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
AHI,30.0,52.138892,30.582165,6.464485,19.596435,65.875894,71.935369,103.329509
FEM-HH,30.0,49.444693,13.943769,25.031245,38.504028,49.08636,58.463774,78.105472
FEM-het,30.0,56.349154,17.059764,24.273328,43.235177,59.661668,69.756955,85.206215
PINN-HH,30.0,65.624415,25.534887,17.066152,63.408148,69.411755,77.919801,106.172179
PINN-het,30.0,45.882514,19.961549,13.259699,32.355035,48.776573,62.391329,81.51141


In [18]:
# Correlation (R) of predicted mu per method on whole-image metrics
# (noise-free runs without Savitzky-Golay filtering).
# `savgol_filter` is a 0/1 integer flag, so it is compared explicitly instead of
# negated with `~`.
fig = ps.plot(
    mm[
        (mm.spatial_region == 'all') &
        (mm.spatial_frequency_bin == 'all') &
        (mm.noise_ratio == 0.0) &
        (mm.savgol_filter == 0)
    ].copy(),
    x='method',
    hue='method',
    y='mu_pred_R',
    height=3,
    width=7,
    legend=False,
    tight=True,
    plot_func=ps.results.barplot
)
# Raw string: '\m' is not a valid Python escape sequence.
fig.axes[0].set_ylabel(r'$\mu$ correlation (R)')
fig.axes[0].set_ylim([0, 1])

# Cycle through the palette so extra patches cannot raise IndexError
# (matches the recoloring loop of the relative-MAD bar plot).
for i, patch in enumerate(fig.axes[0].patches):
    plt.setp(patch, facecolor=colors[i % len(colors)])

for ext in ['png', 'pdf']:
    fig.savefig(f'images/fem_box_method_R_bar_plot.{ext}', bbox_inches='tight', dpi=200)

<IPython.core.display.Javascript object>



In [19]:
# Per-method summary of whole-image correlation (noise-free, unfiltered).
# The 0/1 integer `savgol_filter` flag is compared explicitly rather than negated with `~`.
mm[
    (mm.spatial_region == 'all') &
    (mm.spatial_frequency_bin == 'all') &
    (mm.noise_ratio == 0.0) &
    (mm.savgol_filter == 0)
].groupby(['method'])[['mu_pred_R']].describe()

Unnamed: 0_level_0,mu_pred_R,mu_pred_R,mu_pred_R,mu_pred_R,mu_pred_R,mu_pred_R,mu_pred_R,mu_pred_R
Unnamed: 0_level_1,count,mean,std,min,25%,50%,75%,max
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
AHI,6.0,0.540157,0.114912,0.363666,0.469514,0.594084,0.610758,0.6448
FEM-HH,6.0,0.52027,0.082537,0.381062,0.494133,0.536975,0.556485,0.623765
FEM-het,6.0,0.401579,0.08736,0.327193,0.34454,0.362275,0.445254,0.54619
PINN-HH,6.0,0.527878,0.04674,0.464874,0.499946,0.527807,0.552445,0.59546
PINN-het,6.0,0.698761,0.072817,0.585499,0.65457,0.720973,0.75325,0.768675


In [20]:
[c for c in mm if c.startswith('mu')]

['mu_pred_MSAE',
 'mu_pred_MSAV',
 'mu_true_MSAV',
 'mu_diff_PSD',
 'mu_pred_PSD',
 'mu_true_PSD',
 'mu_pred_MAD',
 'mu_pred_MAV',
 'mu_true_MAV',
 'mu_diff_R',
 'mu_pred_R',
 'mu_true_R',
 'mu_pred_MSAE_relative',
 'mu_pred_MAD_relative']

In [21]:
# Noise ratio in decibels. log10(0) is -inf for the noise-free runs; the
# divide-by-zero warning is expected and silenced here.
with np.errstate(divide='ignore'):
    mm['noise_level'] = 10 * np.log10(mm['noise_ratio'])
# Place the noise-free runs at -60 dB so they can be drawn on a finite axis
# (the plots label this tick '-inf').
mm.loc[np.isinf(mm.noise_level), 'noise_level'] = -60

  result = getattr(ufunc, method)(*inputs, **kwargs)



In [22]:
# contrast metrics
mmm = mm[
    (mm.spatial_frequency_bin == 'all') &
    (mm.spatial_region != 'all')
]
regions = mmm.spatial_region.unique()
mmm = mmm.set_index(index_cols[2:] + param_cols + ['method'])
mmm = mmm.unstack(level=0)
for r in regions:
    mmm['mu_pred_MAC', r] = mmm['mu_pred_MAV', r] - mmm['mu_pred_MAV', '1']
    mmm['mu_true_MAC', r] = mmm['mu_true_MAV', r] - mmm['mu_true_MAV', '1']

mmm = mmm.stack().reset_index().sort_values(['method_index', 'region_index'])
mmm['mu_pred_CTE'] = mmm['mu_pred_MAC'] / mmm['mu_true_MAC'] * 100
mmm

Unnamed: 0,pde_name,example_id,noise_ratio,savgol_filter,method,spatial_region,Lu_true_MAV,Lu_true_MSAV,Lu_true_PSD,Lu_true_R,...,u_pred_MAV,u_pred_MSAE,u_pred_MSAV,u_pred_PSD,u_pred_R,u_true_MAV,u_true_MSAV,u_true_PSD,u_true_R,mu_pred_CTE
0,helmholtz,50,0.00000,0,AHI,1,30.129231,,,,...,0.001412,,,,,0.001413,,,,
15,helmholtz,50,0.00000,1,AHI,1,30.129231,,,,...,0.001413,,,,,0.001413,,,,
30,helmholtz,50,0.00001,0,AHI,1,51.835527,,,,...,0.001411,,,,,0.001413,,,,
45,helmholtz,50,0.00001,1,AHI,1,51.807668,,,,...,0.001412,,,,,0.001412,,,,
60,helmholtz,50,0.00010,0,AHI,1,127.715910,,,,...,0.001414,,,,,0.001418,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1759,hetero,100,0.00100,1,PINN-het,5,100.709811,,,,...,0.000175,,,,,0.000187,,,,2.373984
1769,hetero,100,0.01000,0,PINN-het,5,245.107672,,,,...,0.000158,,,,,0.000321,,,,1.646036
1779,hetero,100,0.01000,1,PINN-het,5,223.788004,,,,...,0.000182,,,,,0.000339,,,,-2.855406
1789,hetero,100,0.10000,0,PINN-het,5,604.407490,,,,...,0.000184,,,,,0.000795,,,,6.080382


In [23]:
mmm[['mu_pred_MAC', 'mu_true_MAC', 'mu_pred_CTE']]

Unnamed: 0,mu_pred_MAC,mu_true_MAC,mu_pred_CTE
0,0.000000,0.000000,
15,0.000000,0.000000,
30,0.000000,0.000000,
45,0.000000,0.000000,
60,0.000000,0.000000,
...,...,...,...
1759,165.101798,6954.628547,2.373984
1769,114.475663,6954.628547,1.646036
1779,-198.582869,6954.628547,-2.855406
1789,422.867950,6954.628547,6.080382


In [24]:
# CTE per target region and method (noise-free runs without Savitzky-Golay filtering).
# `savgol_filter` is a 0/1 integer flag, so it is compared explicitly instead of
# negated with `~`.
fig = ps.plot(
    mmm[
        (mmm.spatial_region.isin(set('2345'))) &
        (mmm.noise_ratio == 0.0) &
        (mmm.savgol_filter == 0)
    ],
    x='spatial_region',
    y='mu_pred_CTE',
    hue='method',
    height=3,
    width=9,
    legend=True,
    tight=True,
    plot_func=ps.results.barplot
)
fig.axes[0].set_ylim(0, 150)
# Raw string: '\m' is not a valid Python escape sequence.
fig.axes[0].set_ylabel(r'$\mu$ CTE (%)')
fig.axes[0].yaxis.set_major_formatter(pct_format)
fig.tight_layout()

# Regions '2'..'5' are presented as targets 1..4.
fig.axes[0].set_xlabel('spatial region')
fig.axes[0].set_xticklabels(['target 1', 'target 2', 'target 3', 'target 4'])

for ext in ['png', 'pdf']:
    fig.savefig(f'images/fem_box_contrast_bar_plot.{ext}', bbox_inches='tight', dpi=200)

<IPython.core.display.Javascript object>

In [25]:
# Summary of CTE per target region and method (noise-free, unfiltered).
# The 0/1 integer `savgol_filter` flag is compared explicitly rather than negated with `~`.
mmm[
    (mmm.spatial_region.isin(set('2345'))) &
    (mmm.noise_ratio == 0.0) &
    (mmm.savgol_filter == 0)
].groupby(['spatial_region', 'method'])[['mu_pred_CTE']].describe()

Unnamed: 0_level_0,Unnamed: 1_level_0,mu_pred_CTE,mu_pred_CTE,mu_pred_CTE,mu_pred_CTE,mu_pred_CTE,mu_pred_CTE,mu_pred_CTE,mu_pred_CTE
Unnamed: 0_level_1,Unnamed: 1_level_1,count,mean,std,min,25%,50%,75%,max
spatial_region,method,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
2,AHI,6.0,97.790984,12.070345,86.385616,90.003452,94.022825,102.034785,118.945535
2,FEM-HH,6.0,108.929942,39.192674,59.55606,80.181085,111.219947,130.691343,164.567089
2,FEM-het,6.0,87.547481,37.157222,53.657147,61.44825,80.869108,95.470679,154.577021
2,PINN-HH,6.0,46.301104,14.043387,24.771183,39.723753,46.129068,56.63716,63.048798
2,PINN-het,6.0,50.910568,16.094744,28.491055,39.962591,52.971099,63.423022,68.496739
3,AHI,6.0,103.5778,52.748358,17.200685,77.561858,116.383852,144.130822,153.497406
3,FEM-HH,6.0,37.899849,23.108028,6.896111,27.449588,36.043296,44.890447,75.991876
3,FEM-het,6.0,29.5411,20.082345,11.435213,13.715927,26.409845,36.548741,63.578702
3,PINN-HH,6.0,35.117533,30.698264,-8.616966,14.159644,40.497941,57.456516,69.836695
3,PINN-het,6.0,40.561392,13.02975,20.949503,33.042727,43.388641,48.162752,56.293355


In [26]:
fig.axes

[<AxesSubplot: xlabel='spatial region', ylabel='$\\mu$ CTE (%)'>]

In [27]:
# NOTE: set_context changes seaborn's global style for every later figure too.
sns.set_context('talk')

# CTE versus noise level over the target regions, unfiltered runs only.
# `savgol_filter` is a 0/1 integer flag, so it is compared explicitly instead of
# negated with `~`.
fig = ps.plot(
    mmm[mmm.spatial_region.isin(set('2345')) & (mmm.savgol_filter == 0)],
    x='noise_level',
    y='mu_pred_CTE',
    hue='method',
    height=8,
    width=8,
    legend=True,
    legend_kws=dict(loc='upper right'),
    tight=True,
    plot_func=sns.lineplot
)
# Noise-free runs were placed at -60 dB, so that tick is labelled '-inf'.
# (The x-axis setup used to be repeated verbatim after tight_layout; once is enough.)
fig.axes[0].set_xlabel('noise level (dB)')
fig.axes[0].set_xticks([-60, -50, -40, -30, -20, -10])
fig.axes[0].set_xticklabels(['-inf', '-50', '-40', '-30', '-20', '-10'])

fig.axes[0].set_ylabel('contrast transfer efficiency (%)')
fig.axes[0].yaxis.set_major_formatter(pct_format)
fig.axes[0].set_ylim(0, 85)
fig.tight_layout()

for ext in ['png', 'pdf']:
    fig.savefig(f'images/fem_box_noise_CTE_plot.{ext}', bbox_inches='tight', dpi=200)

<IPython.core.display.Javascript object>

In [28]:
# Summary of CTE per method and noise level over the target regions (unfiltered).
# The 0/1 integer `savgol_filter` flag is compared explicitly rather than negated with `~`.
mmm[
    mmm.spatial_region.isin(set('2345')) & (mmm.savgol_filter == 0)
].groupby(['method', 'noise_level'])[['mu_pred_CTE']].describe()

Unnamed: 0_level_0,Unnamed: 1_level_0,mu_pred_CTE,mu_pred_CTE,mu_pred_CTE,mu_pred_CTE,mu_pred_CTE,mu_pred_CTE,mu_pred_CTE,mu_pred_CTE
Unnamed: 0_level_1,Unnamed: 1_level_1,count,mean,std,min,25%,50%,75%,max
method,noise_level,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
AHI,-60.0,24.0,51.594662,56.466007,-16.71385,2.635651,15.681878,98.25094,153.497406
AHI,-50.0,24.0,39.745914,43.075296,-14.34495,3.468391,12.734066,72.506382,124.529209
AHI,-40.0,24.0,21.911516,24.616176,-9.791818,1.673378,12.445724,41.168756,76.096459
AHI,-30.0,24.0,6.614186,9.157041,-9.332295,0.615886,3.247031,13.114796,28.258627
AHI,-20.0,24.0,0.818476,3.196871,-5.242409,-1.017765,-0.119299,2.006535,7.589069
AHI,-10.0,24.0,0.144442,1.111623,-2.064324,-0.511847,-0.117007,0.845278,2.576926
FEM-HH,-60.0,24.0,47.383,44.902621,-8.708746,16.038457,36.897571,62.628693,164.567089
FEM-HH,-50.0,24.0,0.249679,5.152029,-10.160486,-2.658329,-1.16749,1.945469,12.621778
FEM-HH,-40.0,24.0,-0.355787,1.127511,-2.53307,-1.231113,-0.320204,0.701276,1.809569
FEM-HH,-30.0,24.0,0.039662,0.336024,-0.677492,-0.228559,0.012624,0.255335,0.725342


In [29]:
# Paired one-sided t tests: for each method, is CTE at a given noise level
# lower than CTE of the same (example, region) pair without noise (-60 dB)?
from scipy.stats import ttest_ind, ttest_rel

for method, df in mmm[
    mmm.spatial_region.isin(set('2345')) & (mmm.savgol_filter == 0)
].groupby('method'):
    print(method)
    df = df.set_index(['example_id', 'spatial_region'])
    sample1 = df.loc[df.noise_level == -60, 'mu_pred_CTE']
    for ndb in [-50, -40, -30, -20, -10]:
        # ttest_rel pairs observations by position, so align the noisy sample
        # to the noise-free sample on (example_id, spatial_region) explicitly.
        # A pair missing from the noisy sample becomes NaN and shows up as a
        # NaN result instead of silently mispairing rows.
        sample2 = df.loc[df.noise_level == ndb, 'mu_pred_CTE'].reindex(sample1.index)
        res = ttest_rel(sample2, sample1, alternative='less')
        print(ndb, res)


AHI
-50 Ttest_relResult(statistic=-3.6831001366170355, pvalue=0.000615718333184049)
-40 Ttest_relResult(statistic=-4.215398903877769, pvalue=0.00016458001065866216)
-30 Ttest_relResult(statistic=-4.435493154127543, pvalue=9.506488440987507e-05)
-20 Ttest_relResult(statistic=-4.494565193267231, pvalue=8.203982339987766e-05)
-10 Ttest_relResult(statistic=-4.469388553400962, pvalue=8.735727641606034e-05)
FEM-HH
-50 Ttest_relResult(statistic=-5.131487387175304, pvalue=1.6842561764556174e-05)
-40 Ttest_relResult(statistic=-5.152522971519955, pvalue=1.598952984516903e-05)
-30 Ttest_relResult(statistic=-5.144855868094972, pvalue=1.629527024781466e-05)
-20 Ttest_relResult(statistic=-5.16633109255666, pvalue=1.5453445770117206e-05)
-10 Ttest_relResult(statistic=-5.17144802082538, pvalue=1.5259431193151796e-05)
FEM-het
-50 Ttest_relResult(statistic=-5.3808867246409084, pvalue=9.11225767575716e-06)
-40 Ttest_relResult(statistic=-5.373568957631742, pvalue=9.277392072226225e-06)
-30 Ttest_relResult

In [30]:
# Relative MAD per region and method (noise-free runs without Savitzky-Golay filtering).
# `savgol_filter` is a 0/1 integer flag, so it is compared explicitly instead of
# negated with `~`.
fig = ps.plot(
    mm[
        (mm.spatial_region != 'all') &
        (mm.noise_ratio == 0.0) &
        (mm.savgol_filter == 0)
    ],
    x='spatial_region',
    y='mu_pred_MAD_relative',
    hue='method',
    height=3,
    width=9,
    legend=True,
    tight=True,
    plot_func=ps.results.barplot
)
# Raw string: '\m' is not a valid Python escape sequence.
fig.axes[0].set_ylabel(r'$\mu$ relative MAD (%)')
fig.axes[0].yaxis.set_major_formatter(pct_format)
fig.tight_layout()

# Regions '1'..'5' are presented as background and targets 1..4.
fig.axes[0].set_xlabel('spatial region')
fig.axes[0].set_xticklabels(['background', 'target 1', 'target 2', 'target 3', 'target 4'])

for ext in ['png', 'pdf']:
    fig.savefig(f'images/fem_box_region_bar_plot.{ext}', bbox_inches='tight', dpi=200)

<IPython.core.display.Javascript object>

In [31]:
# Relative MAD versus example_id (shown as frequency in Hz) for noise-free runs.
# NOTE(review): unlike the neighbouring figures, this one does not exclude
# savgol_filter == 1 rows — confirm that is intended.
fig = ps.plot(
    mm[(mm.spatial_region != 'all') & (mm.noise_ratio == 0.0)],
    x='example_id',
    y='mu_pred_MAD_relative',
    hue='method',
    height=5,
    width=7,
    legend=True,
    tight=True,
    plot_func=sns.lineplot
)
# Raw string: '\m' is not a valid Python escape sequence.
fig.axes[0].set_ylabel(r'$\mu$ relative MAD (%)')
fig.axes[0].yaxis.set_major_formatter(pct_format)
fig.axes[0].set_ylim(0, 250)
fig.tight_layout()

fig.axes[0].set_xlabel('frequency (Hz)')

for ext in ['png', 'pdf']:
    fig.savefig(f'images/fem_box_frequency_bar_plot.{ext}', bbox_inches='tight', dpi=200)

<IPython.core.display.Javascript object>

In [32]:
# Relative MAD versus noise level over the individual regions.
# NOTE(review): unlike the neighbouring figures, this one does not exclude
# savgol_filter == 1 rows — confirm that is intended.
fig = ps.plot(
    mm[(mm.spatial_region != 'all')],
    x='noise_level',
    y='mu_pred_MAD_relative',
    hue='method',
    height=5,
    width=7,
    legend=True,
    tight=True,
    plot_func=sns.lineplot
)
# Raw string: '\m' is not a valid Python escape sequence.
fig.axes[0].set_ylabel(r'$\mu$ relative MAD (%)')
fig.axes[0].yaxis.set_major_formatter(pct_format)
fig.axes[0].set_ylim(0, 140)
fig.tight_layout()

# Noise-free runs were placed at -60 dB, so that tick is labelled '-inf'.
fig.axes[0].set_xlabel('noise level (dB)')
fig.axes[0].set_xticks([-60, -50, -40, -30, -20, -10])
fig.axes[0].set_xticklabels(['-inf', '-50', '-40', '-30', '-20', '-10'])

for ext in ['png', 'pdf']:
    fig.savefig(f'images/fem_box_noise_bar_plot.{ext}', bbox_inches='tight', dpi=200)

<IPython.core.display.Javascript object>

In [33]:
# Summary of relative MAD per method and noise level over individual regions (unfiltered).
# The 0/1 integer `savgol_filter` flag is compared explicitly rather than negated with `~`.
mm[
    (mm.spatial_region != 'all') &
    (mm.savgol_filter == 0)
].groupby(['method', 'noise_level'])[['mu_pred_MAD_relative']].describe()

Unnamed: 0_level_0,Unnamed: 1_level_0,mu_pred_MAD_relative,mu_pred_MAD_relative,mu_pred_MAD_relative,mu_pred_MAD_relative,mu_pred_MAD_relative,mu_pred_MAD_relative,mu_pred_MAD_relative,mu_pred_MAD_relative
Unnamed: 0_level_1,Unnamed: 1_level_1,count,mean,std,min,25%,50%,75%,max
method,noise_level,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
AHI,-60.0,30.0,52.138892,30.582165,6.464485,19.596435,65.875894,71.935369,103.329509
AHI,-50.0,30.0,54.375361,27.984106,16.25885,22.235944,66.071253,71.490693,103.341829
AHI,-40.0,30.0,62.475131,24.067502,19.074933,45.590207,69.541674,76.56392,101.769995
AHI,-30.0,30.0,77.399151,19.005485,33.938863,76.30925,80.080594,91.440775,102.661275
AHI,-20.0,30.0,90.404184,9.421328,64.70356,89.743211,92.128597,96.568284,100.690668
AHI,-10.0,30.0,96.550963,3.431239,85.49581,95.941968,97.226708,98.760257,100.195112
FEM-HH,-60.0,30.0,49.444693,13.943769,25.031245,38.504028,49.08636,58.463774,78.105472
FEM-HH,-50.0,30.0,81.461496,12.984962,50.979723,77.497401,85.427275,90.737015,94.909913
FEM-HH,-40.0,30.0,96.241332,3.518605,87.035854,96.471317,97.363345,98.328879,99.43941
FEM-HH,-30.0,30.0,98.5912,1.185899,94.665915,98.508555,98.944106,99.329434,99.627757


In [34]:
# Two-panel figure: (A) relative MAD per region, (B) relative MAD versus noise level.
# `savgol_filter` is a 0/1 integer flag, so both panels compare it explicitly
# instead of negating it with `~`.
fig, ax = plt.subplots(1, 2, figsize=(8.5, 3), width_ratios=(0.7, 0.3), squeeze=False)

# Panel A: noise-free, unfiltered runs.
ps.results.barplot(
    mm[(mm.spatial_region != 'all') & (mm.noise_ratio == 0.0) & (mm.savgol_filter == 0)],
    x='spatial_region',
    y='mu_pred_MAD_relative',
    hue='method',
    ax=ax[0,0]
)
ax[0,0].grid(linestyle=':')
ax[0,0].axes.set_axisbelow(True)
ax[0,0].legend(frameon=True, edgecolor='w', loc='upper left', fontsize='x-small', ncol=2)

# Panel letter, shifted 0.6 inch left of the axes' top-left corner.
t = mpl.transforms.ScaledTranslation(-0.6, 0, fig.dpi_scale_trans)
ax[0,0].text(
    0, 1, 'A', fontweight='bold', va='bottom', ha='right',
    transform=ax[0,0].transAxes + t
)

# Raw string: '\m' is not a valid Python escape sequence.
ax[0,0].set_ylabel(r'$\mu$ relative MAD (%)')
ax[0,0].set_ylim(0, 105)
ax[0,0].yaxis.set_major_formatter(pct_format)

ax[0,0].set_xlabel('spatial region')
ax[0,0].set_xticklabels(['background', 'target 1', 'target 2', 'target 3', 'target 4'])

# Panel B: all noise levels, unfiltered runs.
sns.lineplot(
    mm[(mm.spatial_region != 'all') & (mm.savgol_filter == 0)],
    x='noise_level',
    y='mu_pred_MAD_relative',
    hue='method',
    ax=ax[0,1]
)
ax[0,1].grid(linestyle=':')
ax[0,1].axes.set_axisbelow(True)
ax[0,1].legend(frameon=True, edgecolor='w', fontsize='x-small', ncol=2, loc='lower right')

# Panel letter, shifted 0.4 inch left of the axes' top-left corner.
t = mpl.transforms.ScaledTranslation(-0.4, 0, fig.dpi_scale_trans)
ax[0,1].text(0, 1, 'B', fontweight='bold', va='bottom', ha='right', transform=ax[0,1].transAxes + t)

ax[0,1].set_ylabel(None)
ax[0,1].set_ylim(0, 105)
ax[0,1].yaxis.set_major_formatter(pct_format)

# Noise-free runs were placed at -60 dB, so that tick is labelled '-inf'.
ax[0,1].set_xlabel('noise level (dB)')
ax[0,1].set_xticks([-60, -50, -40, -30, -20, -10])
ax[0,1].set_xticklabels(['-inf', '-50', '-40', '-30', '-20', '-10'])

sns.despine(fig)
fig.tight_layout()

for ext in ['png', 'pdf']:
    fig.savefig(f'images/fem_box_plots.{ext}', bbox_inches='tight', dpi=200)

<IPython.core.display.Javascript object>

In [35]:
# Summary of relative MAD per method and region (noise-free, unfiltered).
# The 0/1 integer `savgol_filter` flag is compared explicitly rather than negated with `~`.
mm[
    (mm.spatial_region != 'all') &
    (mm.noise_ratio == 0.0) &
    (mm.savgol_filter == 0)
].groupby(['method', 'spatial_region'])[['mu_pred_MAD_relative']].describe()

Unnamed: 0_level_0,Unnamed: 1_level_0,mu_pred_MAD_relative,mu_pred_MAD_relative,mu_pred_MAD_relative,mu_pred_MAD_relative,mu_pred_MAD_relative,mu_pred_MAD_relative,mu_pred_MAD_relative,mu_pred_MAD_relative
Unnamed: 0_level_1,Unnamed: 1_level_1,count,mean,std,min,25%,50%,75%,max
method,spatial_region,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2
AHI,1,6.0,20.477379,3.751072,16.015281,18.65143,19.654685,21.611978,26.958204
AHI,2,6.0,14.519883,8.086173,6.464485,9.367848,12.235311,17.634529,28.337897
AHI,3,6.0,75.8804,12.898376,67.540044,69.426491,72.059835,72.960188,101.813668
AHI,4,6.0,81.106068,15.260826,65.378939,68.720082,78.808711,90.917179,103.329509
AHI,5,6.0,68.710727,5.640336,65.196416,65.705036,66.380598,68.273767,79.922081
FEM-HH,1,6.0,49.713624,8.056941,39.419135,43.278905,50.775084,55.943344,58.866166
FEM-HH,2,6.0,38.721534,13.555974,25.031245,29.515436,35.66284,44.09098,61.605617
FEM-HH,3,6.0,47.626957,13.319063,35.532191,37.279142,45.243418,52.103181,70.728564
FEM-HH,4,6.0,61.798584,11.477178,52.42422,53.018843,56.633862,70.543653,78.105472
FEM-HH,5,6.0,49.362768,15.642861,36.433334,38.813813,41.551845,60.188639,72.337547


In [36]:
# Noise ratio in decibels. log10(0) is -inf for the noise-free runs; the
# divide-by-zero warning is expected and silenced here.
with np.errstate(divide='ignore'):
    mmm['noise_level'] = 10 * np.log10(mmm['noise_ratio'])
# Place the noise-free runs at -60 dB so they can be drawn on a finite axis
# (the plots label this tick '-inf').
mmm.loc[np.isinf(mmm.noise_level), 'noise_level'] = -60

  result = getattr(ufunc, method)(*inputs, **kwargs)



In [37]:
# Two-panel figure: (A) CTE per target region, (B) CTE versus noise level.
# `savgol_filter` is a 0/1 integer flag, so both panels compare it explicitly
# instead of negating it with `~`.
fig, ax = plt.subplots(1, 2, figsize=(8.5, 3), width_ratios=(0.7, 0.3), squeeze=False)

# Panel A: noise-free, unfiltered runs over regions '2'..'5'.
ps.results.barplot(
    mmm[mmm.spatial_region.isin(set('2345')) & (mmm.noise_ratio == 0.0) & (mmm.savgol_filter == 0)],
    x='spatial_region',
    y='mu_pred_CTE',
    hue='method',
    ax=ax[0,0]
)
ax[0,0].grid(linestyle=':')
ax[0,0].axes.set_axisbelow(True)
ax[0,0].legend(frameon=True, edgecolor='w', loc='upper right', fontsize='small', ncol=2)

# Panel letter, shifted 0.6 inch left of the axes' top-left corner.
t = mpl.transforms.ScaledTranslation(-0.6, 0, fig.dpi_scale_trans)
ax[0,0].text(
    0, 1, 'A', fontweight='bold', va='bottom', ha='right',
    transform=ax[0,0].transAxes + t
)

ax[0,0].set_ylabel('CTE (%)')
ax[0,0].set_ylim(0, 125)
ax[0,0].yaxis.set_major_formatter(pct_format)

ax[0,0].set_xlabel('spatial region')
ax[0,0].set_xticklabels(['target 1', 'target 2', 'target 3', 'target 4'])

# Panel B: all noise levels, unfiltered runs over regions '2'..'5'.
sns.lineplot(
    mmm[mmm.spatial_region.isin(set('2345')) & (mmm.savgol_filter == 0)],
    x='noise_level',
    y='mu_pred_CTE',
    hue='method',
    ax=ax[0,1]
)
ax[0,1].grid(linestyle=':')
ax[0,1].axes.set_axisbelow(True)
ax[0,1].legend(frameon=True, edgecolor='w', fontsize='small', ncol=1, loc='upper right')

# Panel letter, shifted 0.4 inch left of the axes' top-left corner.
t = mpl.transforms.ScaledTranslation(-0.4, 0, fig.dpi_scale_trans)
ax[0,1].text(0, 1, 'B', fontweight='bold', va='bottom', ha='right', transform=ax[0,1].transAxes + t)

ax[0,1].set_ylabel('CTE (%)')
ax[0,1].set_ylim(0, 85)
ax[0,1].yaxis.set_major_formatter(pct_format)

# Noise-free runs were placed at -60 dB, so that tick is labelled '-inf'.
ax[0,1].set_xlabel('noise level (dB)')
ax[0,1].set_xticks([-60, -50, -40, -30, -20, -10])
ax[0,1].set_xticklabels(['-inf', '-50', '-40', '-30', '-20', '-10'])

sns.despine(fig)
fig.tight_layout()

for ext in ['png', 'pdf']:
    fig.savefig(f'images/fem_box_contrast_plots.{ext}', bbox_inches='tight', dpi=200)

<IPython.core.display.Javascript object>

In [70]:
%autoreload

# Named colors from seaborn's 'bright' palette.
# NOTE: this rebinds the color names unpacked earlier from the param_search palette.
blue, orange, green, red, purple, brown, pink, gray, yellow, cyan = sns.color_palette('bright')

# Override entries of the mre_pinn.visual.COLORS mapping; yellow, cyan, purple
# and black are given as explicit RGB tuples rather than palette colors.
mre_pinn.visual.COLORS.update(
    red=red,
    yellow=(1, 0.9, 0),
    green=green,
    cyan=(0, 0.8, 0.8),
    blue=blue,
    purple=(1/2, 0, 1/2),
    black=(0, 0, 0),
)

In [71]:
# Shared color scaling for wave images and elastograms; also used by the image grid.
wave_kws = dict(vmin=-1e-2, vmax=1e-2, cmap=mre_pinn.visual.wave_color_map())
elast_kws = dict(vmin=0, vmax=20e3, cmap=mre_pinn.visual.mre_color_map(symmetric=False))

# NOTE(review): this reads 'fem_box2', whereas the training jobs were pointed at
# 'fem_box' — confirm both directories hold the same data.
dataset = mre_pinn.data.MREDataset.load_xarrays(
    xarray_dir='../data/BIOQIC/fem_box2'
)
# Save the region mask and the elastogram of the first loaded example.
dataset[0].view('mre_mask', ax_height=1.5, space=0.2)
plt.gcf().tight_layout()
plt.savefig('images/fem_box_regions.png', dpi=200, bbox_inches='tight')
dataset[0].view('mre', ax_height=1.5, space=0.2, **elast_kws)
plt.gcf().tight_layout()
plt.savefig('images/fem_box_mre.png', dpi=200, bbox_inches='tight')

Loading ../data/BIOQIC/fem_box2/100/wave.nc
Loading ../data/BIOQIC/fem_box2/100/mre.nc
Loading ../data/BIOQIC/fem_box2/100/mre_mask.nc
Loading ../data/BIOQIC/fem_box2/50/wave.nc
Loading ../data/BIOQIC/fem_box2/50/mre.nc
Loading ../data/BIOQIC/fem_box2/50/mre_mask.nc
Loading ../data/BIOQIC/fem_box2/60/wave.nc
Loading ../data/BIOQIC/fem_box2/60/mre.nc
Loading ../data/BIOQIC/fem_box2/60/mre_mask.nc
Loading ../data/BIOQIC/fem_box2/70/wave.nc
Loading ../data/BIOQIC/fem_box2/70/mre.nc
Loading ../data/BIOQIC/fem_box2/70/mre_mask.nc
Loading ../data/BIOQIC/fem_box2/80/wave.nc
Loading ../data/BIOQIC/fem_box2/80/mre.nc
Loading ../data/BIOQIC/fem_box2/80/mre_mask.nc
Loading ../data/BIOQIC/fem_box2/90/wave.nc
Loading ../data/BIOQIC/fem_box2/90/mre.nc
Loading ../data/BIOQIC/fem_box2/90/mre_mask.nc


<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Select…

  plt.gcf().tight_layout()



<IPython.core.display.Javascript object>

interactive(children=(SelectionSlider(description='part', options=(('real', 0), ('imag', 1)), value=0), Select…

  plt.gcf().tight_layout()



In [92]:
# Inspect the region mask of the first loaded example.
# NOTE(review): this cell's execution count is out of sequence with its
# neighbours; re-check it under Restart & Run All.
example = dataset[0]
example.mre_mask

<IPython.core.display.Javascript object>

In [72]:
%autoreload
import mre_pinn

# image grid

# Column titles of the image grid, in left-to-right order.
image_names = [
    'wave image', 'ground truth', 'AHI', 'FEM-HH', 'FEM-het', 'PINN-HH', 'PINN-het'
]

# Experiment directory and job-name pattern used to locate each job's saved outputs.
# NOTE: `name` is rebound here with 'fem_box' hard-coded in place of `{data_name}`.
expt_name = '2023-2-24_sim_noise2' 
name = 'train_fem_box_{example_id}_{noise_ratio:.0e}_{savgol_filter}_{omega}_{pde_name}'

def plot_image_grid(example_ids, label_rows=False):
    """Plot one row of images per example and one column per entry of `image_names`.

    Reads the notebook globals `image_names`, `expt_name`, `name`, `wave_kws`
    and `elast_kws`. Each example is loaded from '../data/BIOQIC/fem_box2';
    reconstructions are read from the job folders under `expt_name` for the
    jobs with noise_ratio=0, savgol_filter=0 and omega=60.

    Args:
        example_ids: sequence of example IDs (e.g. ['90']); one grid row each.
        label_rows: if True, label the first column of each row with
            '<example_id> Hz'. Defaults to False (no row labels).

    Returns:
        The figure created by `mre_pinn.visual.subplot_grid`.

    Raises:
        ValueError: if `image_names` contains a name this function cannot draw.
    """
    n_rows = len(example_ids)
    n_cols = len(image_names)
    ax_width = 1
    ax_height = ax_width / 0.8
    cbar_width = 0.1

    # fixed job parameters selecting which training runs are visualized
    noise_ratio = 0
    savgol_filter = 0
    omega = 60

    # reconstruction column -> (pde_name of the job, output file suffix, variable to select)
    predictions = {
        'AHI': ('helmholtz', 'direct', 'direct_pred'),
        'FEM-HH': ('helmholtz', 'fem', 'fem_pred'),
        'FEM-het': ('hetero', 'fem', 'fem_pred'),
        'PINN-HH': ('helmholtz', 'elastogram', 'mu_pred'),
        'PINN-het': ('hetero', 'elastogram', 'mu_pred'),
    }

    fig, axes, cbar_ax = mre_pinn.visual.subplot_grid(
        n_rows, n_cols, ax_height, ax_width, cbar_width, space=0, pad=(0.35,0.55,0.15,0.25)
    )
    for row_idx, example_id in enumerate(example_ids):
        example = mre_pinn.data.MREExample.load_xarrays('../data/BIOQIC/fem_box2', example_id)
        for col_idx, image_name in enumerate(image_names):
            ax = axes[row_idx,col_idx]
            if label_rows and col_idx == 0:
                ax.set_ylabel(f'{example_id} Hz', fontsize='medium')
            if row_idx == 0:
                ax.set_title(image_name, fontsize='small')
            ax.set_yticks([])
            ax.set_xticks([])

            if image_name == 'regions':
                array = example.mre_mask.sel(z=0)
                color_kws = mre_pinn.visual.get_color_kws(array)

            elif image_name == 'wave image':
                # real part of the z component of the wave field at z=0
                array = example.wave.sel(component='z', z=0).real
                color_kws = wave_kws

            elif image_name in {'ground truth', 'elastogram'}:
                array = np.abs(example.mre.sel(z=0))
                color_kws = elast_kws

            elif image_name in predictions:
                pde_name, suffix, variable = predictions[image_name]
                job_name = name.format(
                    example_id=example_id, pde_name=pde_name, noise_ratio=noise_ratio, omega=omega, savgol_filter=savgol_filter
                )
                nc_file = f'{expt_name}/{job_name}/{job_name}_{suffix}.nc'
                array = mre_pinn.data.dataset.load_xarray_file(nc_file).sel(variable=variable, z=0)
                array = np.abs(array)
                color_kws = elast_kws

            else:
                # previously an unknown name silently re-drew the preceding image
                raise ValueError(f'unknown image name: {image_name!r}')

            im = mre_pinn.visual.imshow(ax, array, **color_kws)

    # The colorbar is built from the last image drawn, so the last column
    # should be an elasticity image for the kPa tick labels to apply.
    plt.colorbar(im, cax=cbar_ax)
    cbar_ax.set_yticks([0, 5e3, 10e3, 15e3, 20e3])
    cbar_ax.set_yticklabels(['0', '5', '10', '15', '20'])
    # raw string: '\m' is not a valid escape sequence in a normal string literal
    cbar_ax.set_ylabel(r'$\mu$ (kPa)')

    return fig

# Draw the grid for example '90' only. The full set of examples is
# ['50', '60', '70', '80', '90', '100']; it can alternatively be split into
# two figures (['50', '60', '70'] and ['80', '90', '100']) saved as
# fem_box_wave_grid0 / fem_box_wave_grid1.
fig = plot_image_grid(example_ids=['90'])

# Save a raster and a vector copy of the figure.
for ext in ('png', 'pdf'):
    fig.savefig(f'images/fem_box_wave_grid.{ext}', dpi=200, bbox_inches='tight')

<IPython.core.display.Javascript object>

Loading ../data/BIOQIC/fem_box2/90/wave.nc
Loading ../data/BIOQIC/fem_box2/90/mre.nc
Loading ../data/BIOQIC/fem_box2/90/mre_mask.nc
Loading 2023-2-24_sim_noise2/train_fem_box_90_0e+00_0_60_helmholtz/train_fem_box_90_0e+00_0_60_helmholtz_direct.nc
Loading 2023-2-24_sim_noise2/train_fem_box_90_0e+00_0_60_helmholtz/train_fem_box_90_0e+00_0_60_helmholtz_fem.nc
Loading 2023-2-24_sim_noise2/train_fem_box_90_0e+00_0_60_hetero/train_fem_box_90_0e+00_0_60_hetero_fem.nc
Loading 2023-2-24_sim_noise2/train_fem_box_90_0e+00_0_60_helmholtz/train_fem_box_90_0e+00_0_60_helmholtz_elastogram.nc
Loading 2023-2-24_sim_noise2/train_fem_box_90_0e+00_0_60_hetero/train_fem_box_90_0e+00_0_60_hetero_elastogram.nc


In [74]:
%autoreload
import mre_pinn

# image grid

# Column titles of the grid, left to right; plot_image_grid draws one column per name.
image_names = [
    'wave image', 'true elasticity', 'FEM reconstruction', 'PINN reconstruction'
]

# Directory that holds one sub-folder of outputs per training job.
expt_name = '2023-2-24_sim_noise2' 
# Job-name template; plot_image_grid formats it to locate each job's .nc output files.
name = 'train_fem_box_{example_id}_{noise_ratio:.0e}_{savgol_filter}_{omega}_{pde_name}'

def plot_image_grid(example_ids, label_rows=False):
    """Plot one row of images per example and one column per entry of `image_names`.

    Reads the notebook globals `image_names`, `expt_name`, `name`, `wave_kws`
    and `elast_kws`. Each example is loaded from '../data/BIOQIC/fem_box2';
    reconstructions are read from the job folders under `expt_name` for the
    jobs with noise_ratio=0, savgol_filter=0 and omega=60.

    Args:
        example_ids: sequence of example IDs (e.g. ['90']); one grid row each.
        label_rows: if True, label the first column of each row with
            '<example_id> Hz'. Defaults to False (no row labels).

    Returns:
        The figure created by `mre_pinn.visual.subplot_grid`.

    Raises:
        ValueError: if `image_names` contains a name this function cannot draw.
    """
    n_rows = len(example_ids)
    n_cols = len(image_names)
    ax_width = 1.8
    ax_height = ax_width / 0.8
    cbar_width = 0.1

    # fixed job parameters selecting which training runs are visualized
    noise_ratio = 0
    savgol_filter = 0
    omega = 60

    # reconstruction column -> (pde_name of the job, output file suffix, variable to select)
    predictions = {
        'AHI': ('helmholtz', 'direct', 'direct_pred'),
        'FEM reconstruction': ('hetero', 'fem', 'fem_pred'),
        'PINN reconstruction': ('hetero', 'elastogram', 'mu_pred'),
    }

    fig, axes, cbar_ax = mre_pinn.visual.subplot_grid(
        n_rows, n_cols, ax_height, ax_width, cbar_width, space=0, pad=(0.35,0.55,0.15,0.25)
    )
    for row_idx, example_id in enumerate(example_ids):
        example = mre_pinn.data.MREExample.load_xarrays('../data/BIOQIC/fem_box2', example_id)
        for col_idx, image_name in enumerate(image_names):
            ax = axes[row_idx,col_idx]
            if label_rows and col_idx == 0:
                ax.set_ylabel(f'{example_id} Hz', fontsize='medium')
            if row_idx == 0:
                ax.set_title(image_name, fontsize='large')
            ax.set_yticks([])
            ax.set_xticks([])

            if image_name == 'regions':
                array = example.mre_mask.sel(z=0)
                color_kws = mre_pinn.visual.get_color_kws(array)

            elif image_name == 'wave image':
                # real part of the z component of the wave field at z=0
                array = example.wave.sel(component='z', z=0).real
                color_kws = wave_kws

            elif image_name in {'true elasticity', 'elastogram'}:
                array = np.abs(example.mre.sel(z=0))
                color_kws = elast_kws

            elif image_name in predictions:
                pde_name, suffix, variable = predictions[image_name]
                job_name = name.format(
                    example_id=example_id, pde_name=pde_name, noise_ratio=noise_ratio, omega=omega, savgol_filter=savgol_filter
                )
                nc_file = f'{expt_name}/{job_name}/{job_name}_{suffix}.nc'
                array = mre_pinn.data.dataset.load_xarray_file(nc_file).sel(variable=variable, z=0)
                array = np.abs(array)
                color_kws = elast_kws

            else:
                # previously an unknown name silently re-drew the preceding image
                raise ValueError(f'unknown image name: {image_name!r}')

            im = mre_pinn.visual.imshow(ax, array, **color_kws)

    # The colorbar is built from the last image drawn, so the last column
    # should be an elasticity image for the kPa tick labels to apply.
    plt.colorbar(im, cax=cbar_ax)
    cbar_ax.set_yticks([0, 5e3, 10e3, 15e3, 20e3])
    cbar_ax.set_yticklabels(['0', '5', '10', '15', '20'])
    # raw string: '\m' is not a valid escape sequence in a normal string literal
    cbar_ax.set_ylabel(r'$\mu$ (kPa)')

    return fig

# Draw the grid for example '90' only. The full set of examples is
# ['50', '60', '70', '80', '90', '100']; it can alternatively be split into
# two figures (['50', '60', '70'] and ['80', '90', '100']) saved as
# fem_box_wave_grid0 / fem_box_wave_grid1.
fig = plot_image_grid(example_ids=['90'])

# Save a raster and a vector copy of the figure.
for ext in ('png', 'pdf'):
    fig.savefig(f'images/F31_pinn_figure.{ext}', dpi=200, bbox_inches='tight')

<IPython.core.display.Javascript object>

Loading ../data/BIOQIC/fem_box2/90/wave.nc
Loading ../data/BIOQIC/fem_box2/90/mre.nc
Loading ../data/BIOQIC/fem_box2/90/mre_mask.nc
Loading 2023-2-24_sim_noise2/train_fem_box_90_0e+00_0_60_hetero/train_fem_box_90_0e+00_0_60_hetero_fem.nc
Loading 2023-2-24_sim_noise2/train_fem_box_90_0e+00_0_60_hetero/train_fem_box_90_0e+00_0_60_hetero_elastogram.nc


In [75]:
# Two-panel figure of `mu_pred_CTE` by method:
#   A - bar plot per spatial region, restricted to noise_ratio == 0
#   B - line plot against noise level
# Both panels keep only spatial regions '2'..'5' and rows where
# `savgol_filter` is not set.
fig, ax = plt.subplots(1, 2, figsize=(8.5, 3), width_ratios=(0.7, 0.3), squeeze=False)
ax_bar, ax_line = ax[0,0], ax[0,1]

in_targets = mmm.spatial_region.isin(set('2345'))
# NOTE(review): `~` is a logical NOT only if savgol_filter is a boolean column - confirm dtype
unfiltered = ~mmm.savgol_filter

# panel A: bars per target region, noise-free runs only
ps.results.barplot(
    mmm[in_targets & (mmm.noise_ratio == 0.0) & unfiltered],
    x='spatial_region',
    y='mu_pred_CTE',
    hue='method',
    ax=ax_bar
)
ax_bar.legend(frameon=True, edgecolor='w', loc='upper right', fontsize='small', ncol=2)
ax_bar.set_xlabel('spatial region')
ax_bar.set_xticklabels(['target 1', 'target 2', 'target 3', 'target 4'])

# panel B: metric versus noise level
sns.lineplot(
    mmm[in_targets & unfiltered],
    x='noise_level',
    y='mu_pred_CTE',
    hue='method',
    ax=ax_line
)
ax_line.legend(frameon=True, edgecolor='w', fontsize='small', ncol=1, loc='upper right')
ax_line.set_xlabel('noise level (dB)')
ax_line.set_xticks([-60, -50, -40, -30, -20, -10])
# the left-most tick (-60) is labelled '-inf'
ax_line.set_xticklabels(['-inf', '-50', '-40', '-30', '-20', '-10'])
ax_line.set_title('noise robustness', fontsize='large')

# styling shared by both panels: dotted grid behind the data, a bold panel
# letter shifted left of the top-left axes corner, and a percent y axis
for panel, letter, x_shift, y_max in [(ax_bar, 'A', -0.6, 125), (ax_line, 'B', -0.4, 85)]:
    panel.grid(linestyle=':')
    panel.set_axisbelow(True)
    t = mpl.transforms.ScaledTranslation(x_shift, 0, fig.dpi_scale_trans)
    panel.text(
        0, 1, letter, fontweight='bold', va='bottom', ha='right',
        transform=panel.transAxes + t
    )
    panel.set_ylabel('CTE (%)')
    panel.set_ylim(0, y_max)
    panel.yaxis.set_major_formatter(pct_format)

sns.despine(fig)
fig.tight_layout()

# Save a raster and a vector copy of the figure.
for ext in ('png', 'pdf'):
    fig.savefig(f'images/fem_box_contrast_plots2.{ext}', dpi=200, bbox_inches='tight')

<IPython.core.display.Javascript object>

In [137]:
%autoreload
import mre_pinn

# image grid

# Row labels of the grid, top to bottom; plot_image_grid draws one row per name.
image_names = [
    'wave image', 'AHI', 'FEM-HH', 'FEM-het', 'PINN-HH', 'PINN-het'
]

# Directory that holds one sub-folder of outputs per training job.
expt_name = '2023-2-24_sim_noise2' 
# Job-name template; plot_image_grid formats it to locate each job's .nc output files.
name = 'train_fem_box_{example_id}_{noise_ratio:.0e}_{savgol_filter}_{omega}_{pde_name}'

def plot_image_grid(noise_ratios):
    """Plot one column per noise ratio and one row per entry of `image_names`.

    Reads the notebook globals `image_names`, `expt_name`, `name`, `wave_kws`
    and `elast_kws`. Example 90 is loaded from '../data/BIOQIC/fem_box2' for
    every column; reconstructions are read from the job folders under
    `expt_name` for the jobs with savgol_filter=0 and omega=60.

    Args:
        noise_ratios: sequence of noise ratios (0 means no noise); one grid
            column each. The bottom row is labelled with the ratio in dB.

    Returns:
        The figure created by `mre_pinn.visual.subplot_grid`.

    Raises:
        ValueError: if `image_names` contains a name this function cannot draw.
    """
    n_rows = len(image_names)
    n_cols = len(noise_ratios)
    ax_height = 1.5
    ax_width = ax_height / 0.8
    cbar_width = 0.2

    # fixed job parameters selecting which training runs are visualized
    example_id = 90
    savgol_filter = 0
    omega = 60

    # reconstruction row -> (pde_name of the job, output file suffix, variable to select)
    predictions = {
        'AHI': ('helmholtz', 'direct', 'direct_pred'),
        'FEM-HH': ('helmholtz', 'fem', 'fem_pred'),
        'FEM-het': ('hetero', 'fem', 'fem_pred'),
        'PINN-HH': ('helmholtz', 'elastogram', 'mu_pred'),
        'PINN-het': ('hetero', 'elastogram', 'mu_pred'),
    }

    fig, axes, cbar_ax = mre_pinn.visual.subplot_grid(
        n_rows, n_cols, ax_height, ax_width, cbar_width, space=0, pad=(0.45,0,0.15,0.25)
    )
    for col_idx, noise_ratio in enumerate(noise_ratios):
        # reload per column so that noise is added to a fresh copy each time
        example = mre_pinn.data.MREExample.load_xarrays('../data/BIOQIC/fem_box2', example_id)
        if noise_ratio > 0:
            # NOTE(review): presumably random and unseeded, so the wave row may
            # differ between runs - confirm against add_gaussian_noise
            example.add_gaussian_noise(noise_ratio)

        # column label, e.g. 1e-3 -> '-30 dB'; round instead of truncating so
        # floating-point error in log10 cannot shift the value by one
        if noise_ratio > 0:
            noise_level = str(int(np.round(10 * np.log10(noise_ratio)))) + ' dB'
        else:
            noise_level = 'no noise'

        for row_idx, image_name in enumerate(image_names):
            ax = axes[row_idx,col_idx]

            if col_idx == 0:
                ax.set_ylabel(image_name)

            if row_idx + 1 == n_rows:
                ax.set_xlabel(noise_level)

            ax.set_yticks([])
            ax.set_xticks([])

            if image_name == 'regions':
                array = example.mre_mask.sel(z=0)
                color_kws = mre_pinn.visual.get_color_kws(array)

            elif image_name == 'wave image':
                # real part of the z component of the wave field at z=0
                array = example.wave.sel(component='z', z=0).real
                color_kws = wave_kws

            elif image_name in {'ground truth', 'elastogram'}:
                array = np.abs(example.mre.sel(z=0))
                color_kws = elast_kws

            elif image_name in predictions:
                pde_name, suffix, variable = predictions[image_name]
                job_name = name.format(
                    example_id=example_id, pde_name=pde_name, noise_ratio=noise_ratio, omega=omega, savgol_filter=savgol_filter
                )
                nc_file = f'{expt_name}/{job_name}/{job_name}_{suffix}.nc'
                array = mre_pinn.data.dataset.load_xarray_file(nc_file).sel(variable=variable, z=0)
                array = np.abs(array)
                color_kws = elast_kws

            else:
                # previously an unknown name silently re-drew the preceding image
                raise ValueError(f'unknown image name: {image_name!r}')

            # transpose and reverse the first axis before drawing
            im = mre_pinn.visual.imshow(ax, array.T[::-1], **color_kws)

    # The colorbar is built from the last image drawn, so the last row
    # should be an elasticity image for the kPa tick labels to apply.
    plt.colorbar(im, cax=cbar_ax)
    cbar_ax.set_yticks([0, 5e3, 10e3, 15e3, 20e3])
    cbar_ax.set_yticklabels(['0 kPa', '5', '10', '15', '20 kPa'])

    return fig

# One column per noise ratio, starting with the noise-free case.
fig = plot_image_grid(noise_ratios=[0, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1])

# Save a raster and a vector copy of the figure.
for ext in ('png', 'pdf'):
    fig.savefig(f'images/fem_box_noise_grid.{ext}', dpi=200, bbox_inches='tight')

<IPython.core.display.Javascript object>

Loading ../data/BIOQIC/fem_box2/90/wave.nc
Loading ../data/BIOQIC/fem_box2/90/mre.nc
Loading ../data/BIOQIC/fem_box2/90/mre_mask.nc
Loading 2023-2-24_sim_noise2/train_fem_box_90_0e+00_0_60_helmholtz/train_fem_box_90_0e+00_0_60_helmholtz_direct.nc
Loading 2023-2-24_sim_noise2/train_fem_box_90_0e+00_0_60_helmholtz/train_fem_box_90_0e+00_0_60_helmholtz_fem.nc
Loading 2023-2-24_sim_noise2/train_fem_box_90_0e+00_0_60_hetero/train_fem_box_90_0e+00_0_60_hetero_fem.nc
Loading 2023-2-24_sim_noise2/train_fem_box_90_0e+00_0_60_helmholtz/train_fem_box_90_0e+00_0_60_helmholtz_elastogram.nc
HERE
Loading 2023-2-24_sim_noise2/train_fem_box_90_0e+00_0_60_hetero/train_fem_box_90_0e+00_0_60_hetero_elastogram.nc
Loading ../data/BIOQIC/fem_box2/90/wave.nc
Loading ../data/BIOQIC/fem_box2/90/mre.nc
Loading ../data/BIOQIC/fem_box2/90/mre_mask.nc
Loading 2023-2-24_sim_noise2/train_fem_box_90_1e-05_0_60_helmholtz/train_fem_box_90_1e-05_0_60_helmholtz_direct.nc
Loading 2023-2-24_sim_noise2/train_fem_box_90_1e-05

In [136]:
%autoreload
import mre_pinn

# Experiment directory and job-name template, set to the same values as in the other grid cells.
expt_name = '2023-2-24_sim_noise2' 
name = 'train_fem_box_{example_id}_{noise_ratio:.0e}_{savgol_filter}_{omega}_{pde_name}'

def plot_image_grid():
    """Plot a single row: the true elasticity followed by one wave image per example.

    Reads the notebook globals `dataset`, `wave_kws` and `elast_kws`. The
    number of wave columns is `len(dataset)`; the examples themselves are
    loaded by ID from '../data/BIOQIC/fem_box2' as 50, 60, 70, ... (assumes
    the dataset holds examples in steps of 10 starting at 50 - TODO confirm).

    Returns:
        The figure created by `mre_pinn.visual.subplot_grid`.
    """
    n_rows = 1
    n_cols = len(dataset) + 1
    ax_height = 1.5
    ax_width = ax_height / 0.8
    cbar_width = 0.2

    fig, axes, cbar_ax = mre_pinn.visual.subplot_grid(
        n_rows, n_cols, ax_height, ax_width, cbar_width, space=0, pad=(0.45,0,0.15,0.25)
    )

    row_idx = 0
    for col_idx in range(n_cols):
        show_truth = (col_idx == 0)

        # column 0 shows the elasticity of example 50; columns 1.. show the
        # wave images of examples 50, 60, 70, ...
        example_id = 50 if show_truth else 50 + (col_idx - 1) * 10
        example = mre_pinn.data.MREExample.load_xarrays('../data/BIOQIC/fem_box2', example_id)

        ax = axes[row_idx,col_idx]
        ax.set_yticks([])
        ax.set_xticks([])

        if show_truth:
            array = np.abs(example.mre.sel(z=0))
            color_kws = elast_kws
            ax.set_title('true elasticity')
        else:
            # real part of the z component of the wave field at z=0
            array = example.wave.sel(component='z', z=0).real
            color_kws = wave_kws
            ax.set_title(f'{example_id} Hz')

        # transpose and reverse the first axis before drawing
        im = mre_pinn.visual.imshow(ax, array.T[::-1], **color_kws)

    # The colorbar is built from the last image drawn, which is a wave image.
    plt.colorbar(im, cax=cbar_ax)
    # Pin the tick positions before assigning labels; setting labels on
    # automatically placed ticks triggers a matplotlib warning and can
    # attach the labels to the wrong ticks.
    # NOTE(review): the label text assumes the wave color limits are
    # symmetric about zero at +/-0.01 mm - confirm against wave_kws.
    vmin, vmax = im.get_clim()
    cbar_ax.set_yticks([vmin, 0.5 * (vmin + vmax), vmax])
    cbar_ax.set_yticklabels(['-0.01 mm', '0', '0.01 mm'])

    return fig

# Build the data overview row (true elasticity plus one wave image per example).
fig = plot_image_grid()

# Save a raster and a vector copy of the figure.
for ext in ('png', 'pdf'):
    fig.savefig(f'images/fem_box_data_grid.{ext}', dpi=200, bbox_inches='tight')

<IPython.core.display.Javascript object>

Loading ../data/BIOQIC/fem_box2/50/wave.nc
Loading ../data/BIOQIC/fem_box2/50/mre.nc
Loading ../data/BIOQIC/fem_box2/50/mre_mask.nc
Loading ../data/BIOQIC/fem_box2/50/wave.nc
Loading ../data/BIOQIC/fem_box2/50/mre.nc
Loading ../data/BIOQIC/fem_box2/50/mre_mask.nc
Loading ../data/BIOQIC/fem_box2/60/wave.nc
Loading ../data/BIOQIC/fem_box2/60/mre.nc
Loading ../data/BIOQIC/fem_box2/60/mre_mask.nc
Loading ../data/BIOQIC/fem_box2/70/wave.nc
Loading ../data/BIOQIC/fem_box2/70/mre.nc
Loading ../data/BIOQIC/fem_box2/70/mre_mask.nc
Loading ../data/BIOQIC/fem_box2/80/wave.nc
Loading ../data/BIOQIC/fem_box2/80/mre.nc
Loading ../data/BIOQIC/fem_box2/80/mre_mask.nc
Loading ../data/BIOQIC/fem_box2/90/wave.nc
Loading ../data/BIOQIC/fem_box2/90/mre.nc
Loading ../data/BIOQIC/fem_box2/90/mre_mask.nc


  cbar_ax.set_yticklabels(['-0.01 mm', '0', '0.01 mm'])



Loading ../data/BIOQIC/fem_box2/100/wave.nc
Loading ../data/BIOQIC/fem_box2/100/mre.nc
Loading ../data/BIOQIC/fem_box2/100/mre_mask.nc
