In [1]:
# Notebook setup: interactive matplotlib backend plus the autoreload extension.
%matplotlib notebook
%load_ext autoreload
# Mode 1: before each cell runs, reload only the modules registered via %aimport.
%autoreload 1
# Record the host and working directory this notebook was executed from.
!hostname
!pwd

dv001.ib.bridges2.psc.edu
/ocean/projects/asc170022p/mtragoza/mre-pinn/IPMI-2023


In [2]:
import sys, os
import numpy as np
import pandas as pd

# Put the parent directory on the import path, then import mre_pinn and
# register it for autoreload.
sys.path.append('..')
%aimport mre_pinn

# Same for param_search, imported from ../../param_search.
sys.path.append('../../param_search')
%aimport param_search
ps = param_search  # short alias used by the cells below

Using backend: pytorch



# IPMI 2023 cohort experiment

In [3]:
# Show the current working directory; the relative paths below depend on it.
%pwd

'/ocean/projects/asc170022p/mtragoza/mre-pinn/IPMI-2023'

In [25]:
# define the job template and name format
#
# `template` is a batch script with {placeholders}; it and `name` are passed
# to ps.submit together with the parameter space.
# The two *_iters placeholders are followed by a literal '000', so those
# parameters are specified in thousands of iterations (hence the 'k' in `name`).

template = '''\
#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --account=asc170022p
#SBATCH --partition=GPU-shared
#SBATCH --gres=gpu:1
#SBATCH --time=48:00:00
#SBATCH -o %J.stdout
#SBATCH -e %J.stderr
#SBATCH --mail-type=all

hostname
pwd
source activate MRE-PINN

python ../../../train_pino.py \\
    --xarray_dir {xarray_dir} \\
    --frequency {frequency} \\
    --pde_name {pde_name} \\
    --pde_init_weight {pde_init_weight} \\
    --pde_loss_weight {pde_loss_weight} \\
    --pde_warmup_iters {pde_warmup_iters}000 \\
    --pde_step_iters {pde_step_iters}000 \\
    --learning_rate {learning_rate} \\
    --conditional {conditional} \\
    --parallel {parallel} \\
    --save_prefix {job_name}    
'''
# Job name format; learning_rate is rendered in exponent notation (e.g. 1e-05).
name = 'train_{data_name}_{conditional}_{parallel}_{pde_name}_{learning_rate:.0e}_{pde_warmup_iters}k_{pde_step_iters}k'

# define the parameter space

# Settings common to every job. The list-valued entries are the swept ones:
# 2 PDEs x 2 warmup lengths x 2 step lengths x 4 learning rates = 32 combinations.
shared_settings = ps.ParamSpace(
    data_name='cohort',
    xarray_dir='../../../data/NAFLD/v3',
    frequency=40,
    pde_name=['helmholtz', 'hetero'],
    pde_init_weight=1e-18,
    pde_loss_weight=1e-16,
    pde_warmup_iters=[10, 20],
    pde_step_iters=[5, 10],
    learning_rate=[1e-4, 5e-5, 2e-5, 1e-5],
)

# Three (conditional, parallel) variants: (0, 0), (1, 0) and (1, 1).
non_parallel_variants = ps.ParamSpace(
    conditional=[0, 1],
    parallel=0
)
parallel_variant = ps.ParamSpace(
    conditional=1,
    parallel=1
)

cohort_space = shared_settings * (non_parallel_variants + parallel_variant)

param_space = cohort_space
len(param_space)

96

In [26]:
# Sanity check: render the job name for one entry of the parameter space.
name.format(**param_space[10])

'train_cohort_1_0_helmholtz_1e-05_10k_5k'

In [29]:
%autoreload
expt_name = '2022-11-28_cohort_init'

# NOTE(review): re-running this cell calls ps.submit again — presumably that
# submits every job a second time; confirm before re-executing.
jobs = ps.submit(template, name, param_space, work_dir=expt_name, verbose=True)
# Write the returned job table to '<expt_name>.jobs' in the current directory.
jobs.to_csv(f'{expt_name}.jobs')

100%|██████████| 96/96 [00:00<00:00, 141.75it/s]
[13341500, 13341501, 13341502, 13341503, 13341504, 13341505, 13341506, 13341507, 13341508, 13341509, 13341510, 13341511, 13341512, 13341513, 13341514, 13341515, 13341516, 13341517, 13341518, 13341519, 13341520, 13341521, 13341522, 13341523, 13341524, 13341525, 13341526, 13341527, 13341528, 13341529, 13341530, 13341531, 13341532, 13341533, 13341534, 13341535, 13341536, 13341537, 13341538, 13341539, 13341540, 13341541, 13341542, 13341543, 13341544, 13341545, 13341546, 13341547, 13341548, 13341549, 13341550, 13341551, 13341552, 13341553, 13341554, 13341555, 13341556, 13341557, 13341558, 13341559, 13341560, 13341561, 13341562, 13341563, 13341564, 13341565, 13341566, 13341567, 13341568, 13341569, 13341570, 13341571, 13341572, 13341573, 13341574, 13341575, 13341576, 13341577, 13341578, 13341579, 13341580, 13341581, 13341582, 13341583, 13341584, 13341585, 13341586, 13341587, 13341588, 13341589, 13341590, 13341591, 13341592, 13341593, 13341594, 

In [None]:
# Query the status of the submitted jobs and show the most useful columns.
status = ps.status(jobs)
status_cols = ['job_name', 'job_state', 'node_id', 'runtime', 'stdout', 'stderr']
status.loc[:, status_cols]

In [None]:
# Jobs whose stderr contains anything other than the backend banner.
expected_stderr = 'Using backend: pytorch\n\n'
errors = status.loc[status['stderr'] != expected_stderr]
errors

In [None]:
param_cols = ['pde_name', 'example_id'] # experimental parameters
index_cols = ['variable_name', 'spatial_frequency_bin', 'spatial_region'] # metric identifiers
metric_cols = ['MSAV', 'PSD', 'MAV'] # metric values

metrics = ps.metrics(jobs)

# did all models train to 100k iterations?
#assert (metrics.groupby('job_name')['iteration'].max() == 100e3).all()

# get the final test evaluations (rows logged at iteration 100k)
is_final_iteration = metrics['iteration'] == 100e3
metrics = metrics.loc[is_final_iteration]

# average each metric within every (parameters, identifiers) group,
# keeping the groups in their order of first appearance
group_keys = param_cols + index_cols
metrics = metrics.groupby(group_keys, sort=False)[metric_cols].mean()

# pivot the first identifier level ('variable_name') into the columns
variable_level = len(param_cols)
metrics = metrics.unstack(level=[variable_level])

def metric_map(t):
    """Build a flat column name from a (metric_name, var_name) column tuple.

    The base name is '<var_name>_<metric_name>'. A metric computed on a
    difference variable is then renamed as an error metric:

    * 'pde_diff_MSAV' -> 'PDE_MSAE'
    * '*diff_MSAV'    -> '*pred_MSAE'
    * '*diff_MAV'     -> '*pred_MAD'

    Parameters
    ----------
    t : tuple of (str, str)
        (metric_name, var_name), as produced by flattening the column
        MultiIndex.

    Returns
    -------
    str
        The flattened (and possibly renamed) column name.
    """
    metric_name, var_name = t
    new_col_name = f'{var_name}_{metric_name}'
    # The PDE-specific rename must run first: 'diff_MSAV' is a substring of
    # 'pde_diff_MSAV', so the generic rename would otherwise consume it and
    # the PDE rename could never match.
    new_col_name = new_col_name.replace('pde_diff_MSAV', 'PDE_MSAE')
    new_col_name = new_col_name.replace('diff_MSAV', 'pred_MSAE')
    new_col_name = new_col_name.replace('diff_MAV', 'pred_MAD')
    return new_col_name

# Collapse the two-level (metric, variable) columns into single flat names.
flat_column_keys = metrics.columns.to_flat_index()
metrics.columns = list(map(metric_map, flat_column_keys))
metrics

In [None]:
# Turn the grouped index levels back into ordinary columns for plotting.
m = metrics.reset_index()

# Disabled overview plot of the MSAE columns against the experimental
# parameters, kept for reference.
#fig = ps.plot(
#    m[m.spatial_region == 'all'],
#    x=param_cols,
#    y=['u_pred_MSAE', 'mu_pred_MSAE', 'direct_pred_MSAE', 'fem_pred_MSAE'],
#    height=2.5,
#    width=2.5,
#    legend=False,
#    tight=True
#)

In [None]:
# List the available columns (index levels plus flattened metric names).
m.columns

In [None]:
# Build one long table with a 'method' column. The PINN rows keep their own
# 'mu_*' metrics; each baseline gets a copy of the matching rows in which the
# baseline's 'direct*' / 'fem*' columns are written over the 'mu' names, so
# every method can be plotted through the same 'mu_*' columns.
m['method'] = 'PINN_' + m['pde_name']

helmholtz_rows = m['pde_name'] == 'helmholtz'
hetero_rows = m['pde_name'] == 'hetero'
direct_cols = [c for c in m.columns if c.startswith('direct')]
fem_cols = [c for c in m.columns if c.startswith('fem')]

# direct inversion baseline, taken from the helmholtz rows
direct_m = m.loc[helmholtz_rows].copy()
direct_m['method'] = 'direct_helmholtz'
for src_col in direct_cols:
    direct_m[src_col.replace('direct', 'mu')] = m[src_col]

# FEM baseline, taken from the helmholtz rows
hh_fem_m = m.loc[helmholtz_rows].copy()
hh_fem_m['method'] = 'FEM_helmholtz'
for src_col in fem_cols:
    hh_fem_m[src_col.replace('fem', 'mu')] = m[src_col]

# FEM baseline, taken from the hetero rows
ht_fem_m = m.loc[hetero_rows].copy()
ht_fem_m['method'] = 'FEM_hetero'
for src_col in fem_cols:
    ht_fem_m[src_col.replace('fem', 'mu')] = m[src_col]

mm = pd.concat([direct_m, hh_fem_m, ht_fem_m, m])

In [None]:
# Express each error metric relative to the matching metric of the true mu.
for error_col, reference_col in [
    ('mu_pred_MSAE', 'mu_true_MSAV'),
    ('mu_pred_MAD', 'mu_true_MAV'),
]:
    mm[f'{error_col}_relative'] = mm[error_col] / mm[reference_col]

In [None]:
%autoreload
import seaborn as sns
import matplotlib.pyplot as plt

# Fetch a ten-colour palette and give each entry a name.
colors = ps.results.get_color_palette(10, type='deep', min_val=0)
blue, orange, green, red, purple, brown, pink, gray, yellow, cyan = colors

# Pick five of them and install them as the seaborn default palette.
colors = [blue, yellow, red, cyan, green]
sns.set_palette(colors)
# Read the palette back from seaborn and display it as the cell output.
colors = sns.color_palette()
colors

In [None]:
# Bar plot of relative MAD of mu per method, leaving out the rows whose
# spatial_region is '0.0' or 'all'.
region_rows = mm[~mm.spatial_region.isin({'0.0', 'all'})]
plot_kws = dict(
    x='method',
    y='mu_pred_MAD_relative',
    height=3,
    width=7,
    legend=True,
    tight=True,
    plot_func=ps.results.barplot
)
fig = ps.plot(region_rows, **plot_kws)

# recolour the bars one by one from the palette
bars = fig.axes[0].patches
for i in range(len(bars)):
    plt.setp(bars[i], facecolor=colors[i])

In [None]:
# Bar plot of relative MAD of mu per spatial region, one hue per method
# (rows whose spatial_region is '0.0' or 'all' are left out).
region_rows = mm[~mm.spatial_region.isin({'0.0', 'all'})]
plot_kws = dict(
    x='spatial_region',
    y='mu_pred_MAD_relative',
    hue='method',
    height=3,
    width=7,
    legend=True,
    tight=True,
    plot_func=ps.results.barplot
)
fig = ps.plot(region_rows, **plot_kws)

In [None]:
# Bar plot of relative MAD of mu per example, one hue per method
# (rows whose spatial_region is '0.0' or 'all' are left out).
# The figure is made very wide (width=90) to fit one group per example_id.
region_rows = mm[~mm.spatial_region.isin({'0.0', 'all'})]
plot_kws = dict(
    x='example_id',
    y='mu_pred_MAD_relative',
    hue='method',
    height=3,
    width=90,
    legend=True,
    tight=True,
    plot_func=ps.results.barplot
)
fig = ps.plot(region_rows, **plot_kws)

In [None]:
# Bar plot of relative MSAE of mu per method, using only the rows aggregated
# over all regions and all spatial frequency bins; y axis is logarithmic.
all_regions = mm.spatial_region == 'all'
all_frequency_bins = mm.spatial_frequency_bin == 'all'
plot_kws = dict(
    x='method',
    y='mu_pred_MSAE_relative',
    height=7,
    width=7,
    legend=True,
    tight=True,
    plot_func=ps.results.barplot
)
fig = ps.plot(mm[all_regions & all_frequency_bins], **plot_kws)
fig.axes[0].set_yscale('log')

# recolour the bars one by one from the palette
bars = fig.axes[0].patches
for i in range(len(bars)):
    plt.setp(bars[i], facecolor=colors[i])

In [None]:
%autoreload
import mre_pinn

# Column titles of the image grid, in display order; plot_image_grid picks
# what to draw in each column by comparing against these strings.
image_names = ['wave image', 'ground truth', 'direct_helmholtz', 'FEM_helmholtz', 'FEM_hetero', 'PINN_helmholtz', 'PINN_hetero']

# Predicted elastograms: column title -> (relative path template of the saved
# file, entry to select along its 'variable' dimension).
_PREDICTION_SOURCES = {
    'direct_helmholtz': (
        '2022-11-26_patient_hetero2/train_patient_{example_id}_helmholtz/train_patient_{example_id}_helmholtz_direct.nc',
        'direct_pred',
    ),
    'FEM_helmholtz': (
        '2022-11-26_patient_hetero2/train_patient_{example_id}_helmholtz/train_patient_{example_id}_helmholtz_fem.nc',
        'fem_pred',
    ),
    'FEM_hetero': (
        '2022-11-23_patient_init/train_patient_{example_id}_hetero/train_patient_{example_id}_hetero_fem.nc',
        'fem_pred',
    ),
    'PINN_helmholtz': (
        '2022-11-26_patient_hetero2/train_patient_{example_id}_helmholtz/train_patient_{example_id}_helmholtz_elastogram.nc',
        'mu_pred',
    ),
    'PINN_hetero': (
        '2022-11-23_patient_init/train_patient_{example_id}_hetero/train_patient_{example_id}_hetero_elastogram.nc',
        'mu_pred',
    ),
}


def _symmetric_color_kws(array):
    """Return mre_pinn.visual.get_color_kws(array) with vmin forced to -vmax."""
    color_kws = mre_pinn.visual.get_color_kws(array)
    color_kws['vmin'] = -color_kws['vmax']
    return color_kws


def _load_grid_image(example, example_id, image_name):
    """Return the (array, color_kws) pair to draw in one cell of the grid.

    Every image is taken at index 0 of its last axis. Colour limits are always
    computed from the array before it is converted for display.

    Raises
    ------
    ValueError
        If image_name is not a known column title.
    """
    if image_name == 'wave image':
        wave = example.wave[...,0]
        color_kws = _symmetric_color_kws(wave)
        # show the real part, multiplied by the example's mask
        return wave.real * example.mre_mask.values[...,0], color_kws

    if image_name == 'ground truth':
        truth = example.mre[...,0]
        color_kws = _symmetric_color_kws(truth)
        # show the magnitude, multiplied by the example's mask
        return np.abs(truth) * example.mre_mask.values[...,0], color_kws

    if image_name in _PREDICTION_SOURCES:
        path_template, variable = _PREDICTION_SOURCES[image_name]
        nc_file = path_template.format(example_id=example_id)
        pred = mre_pinn.data.dataset.load_xarray_file(nc_file).sel(variable=variable)[...,0]
        color_kws = _symmetric_color_kws(pred)
        # predictions are shown as magnitudes, without masking
        return np.abs(pred), color_kws

    # Previously an unrecognised title fell through every branch and silently
    # drew the previous column's array again (or hit a NameError on the very
    # first cell); fail loudly instead.
    raise ValueError(f'unknown image name: {image_name!r}')


def plot_image_grid(example_ids):
    """Draw a grid of images with one row per example and one column per
    entry of the module-level `image_names`.

    Parameters
    ----------
    example_ids : sequence of str
        Example identifiers; each is loaded from '../data/NAFLD/v3' and is
        also substituted into the paths of the saved prediction files.

    Returns
    -------
    None
        The figure is created through mre_pinn.visual.subplot_grid.

    Raises
    ------
    ValueError
        If `image_names` contains a title this function does not know how
        to load.
    """
    n_rows = len(example_ids)
    n_cols = len(image_names)
    ax_width = 1.2
    ax_height = ax_width

    _, axes, _ = mre_pinn.visual.subplot_grid(n_rows, n_cols, ax_height, ax_width, space=0, pad=(0.35,0.15,0.15,0.25))
    for row_idx, example_id in enumerate(example_ids):
        example = mre_pinn.data.MREExample.load_xarrays('../data/NAFLD/v3', example_id)
        for col_idx, image_name in enumerate(image_names):
            ax = axes[row_idx,col_idx]
            # label rows on the first column and columns on the first row only
            if col_idx == 0:
                ax.set_ylabel(example_id, fontsize='medium')
            if row_idx == 0:
                ax.set_title(image_name, fontsize='small')
            ax.set_yticks([])
            ax.set_xticks([])

            array, color_kws = _load_grid_image(example, example_id, image_name)
            mre_pinn.visual.imshow(ax, array, **color_kws)

# Draw two grids of four examples each.
for example_ids in (
    ['0006', '0020', '0024', '0029'],
    ['0043', '0047', '0126', '0135'],
):
    plot_image_grid(example_ids)