In [1]:
import h5py
import os

import numpy as np
from matplotlib import pyplot as plt
plt.rcParams.update(plt.rcParamsDefault)
import seaborn as sns
import pandas as pd

from interlacer import utils
from scripts import training_config
from visualizations import visualization_lib, plotting_utils

## Undersampling: Ablations + Comparison with Other Methods

In [2]:
# Zoom-patch offsets (x, y) and side length; same values as the loader functions' default arguments.
patch_inds = (50, 50)
patch_n = 64

def load_ref_imgs(h5_path, inputs, x=50, y=50, n=64):
    """Load a reference test-set array, its image-domain version, and a zoomed patch.

    Args:
        h5_path: Path to an HDF5 file containing 'inputs' and 'outputs' datasets.
        inputs: If truthy read the 'inputs' dataset, otherwise read 'outputs'.
        x, y: Start offsets of the patch along axes 1 and 2.
        n: Side length of the square patch.

    Returns:
        (img, k, patch): `utils.convert_to_image_domain(k)`, the raw array read from
        the file (presumably frequency-domain data -- confirm), and
        `img[:, x:x+n, y:y+n, :]`.
    """
    key = 'inputs' if inputs else 'outputs'
    # `[()]` copies the dataset into memory, so the file can be closed right away.
    with h5py.File(h5_path, 'r') as f:
        k = f[key][()]
    img = utils.convert_to_image_domain(k)
    patch = img[:, x:x + n, y:y + n, :]
    return img, k, patch

def load_imgs(h5_path, x=50, y=50, n=64):
    """Load a model's output image, output frequency data, and a zoomed patch.

    Args:
        h5_path: Path to a results HDF5 file with 'img_outputs' and 'freq_outputs'.
        x, y: Start offsets of the patch along axes 1 and 2.
        n: Side length of the square patch.

    Returns:
        (img, k, patch): the 'img_outputs' array, the 'freq_outputs' array, and
        `img[:, x:x+n, y:y+n, :]`.
    """
    # Read both datasets fully into memory, then release the file handle.
    with h5py.File(h5_path, 'r') as f:
        img = f['img_outputs'][()]
        k = f['freq_outputs'][()]
    patch = img[:, x:x + n, y:y + n, :]
    return img, k, patch

def load_ssims(h5_path):
    """Return the full contents of the 'ssim' dataset from a results HDF5 file.

    The file is closed before returning; `[()]` copies the data into memory.
    """
    with h5py.File(h5_path, 'r') as f:
        return f['ssim'][()]

def load_psnrs(h5_path):
    """Return the full contents of the 'psnr' dataset from a results HDF5 file.

    The file is closed before returning; `[()]` copies the data into memory.
    """
    with h5py.File(h5_path, 'r') as f:
        return f['psnr'][()]

def plot_imgs(axes, col, img, k, patch, gt_patch, idx, title=None):
    """Fill column `col` of a 4-row axes grid for example index `idx`.

    Rows, top to bottom: the full image, the zoomed patch, the patch compared
    against `gt_patch` via `plotting_utils.plot_img_diff`, and `k` via
    `plotting_utils.plot_k`. `title` is set on the top axes of the column.
    """
    img_ax, patch_ax, diff_ax, k_ax = (axes[row][col] for row in range(4))

    img_ax.set_title(title)
    plotting_utils.plot_img(img, idx, ax=img_ax)
    plotting_utils.plot_img(patch, idx, ax=patch_ax)
    plotting_utils.plot_img_diff(patch, gt_patch, idx, ax=diff_ax)
    plotting_utils.plot_k(k, idx, ax=k_ax)
    

In [3]:
# Maps "<architecture> <input_domain>" strings from a training config to plot labels.
_CONFIG_TITLE_TO_LABEL = {
    'CONV_RESIDUAL FREQ': 'Frequency',
    'CONV_RESIDUAL IMAGE': 'Image',
    'ALTERNATING_RESIDUAL FREQ': 'Alternating',
    'INTERLACER_RESIDUAL FREQ': 'Interleaved',
}


def remap_title(title):
    """Return the display label for a config title; raises KeyError if unknown."""
    label = _CONFIG_TITLE_TO_LABEL[title]
    return label

# Qualitative comparison grid (4 rows x 9 columns): ground truth, input,
# the four undersampling ablation models, and the three matched baselines.
# Spacing comes from gridspec_kw; figure-level subplots_adjust would be overridden by it.
fig, axes = plt.subplots(4, 9, figsize=(18, 8), gridspec_kw={'wspace': 0.05, 'hspace': 0.05})
exp_dir = '/data/vision/polina/scratch/nmsingh/dev/fouriernetworks/training/melba/undersample_ablations'
# NOTE(review): fixed reorder of the sorted model directories; presumably yields
# Frequency, Image, Alternating, Interleaved -- confirm against each model's config.
baseline_dirs = sorted(os.listdir(exp_dir))
baseline_dirs = [baseline_dirs[i] for i in [1, 2, 0, 3]]

ref_h5_path = '/data/vision/polina/scratch/nmsingh/dev/fouriernetworks/testsets/brain_undersample_test.h5'
ind = 4  # index of the test example drawn in every column

gt_img, gt_k, gt_patch = load_ref_imgs(ref_h5_path, inputs=False)
plot_imgs(axes, 0, gt_img, gt_k, gt_patch, gt_patch, ind, title='Ground Truth')

img, k, patch = load_ref_imgs(ref_h5_path, inputs=True)
plot_imgs(axes, 1, img, k, patch, gt_patch, ind, title='Input')
input_k = k

# The first two ablation models fill columns 2-3 and the last two fill columns 7-8;
# columns 4-6 are reserved for the baselines drawn below.
for col, exp in zip([2, 3, 7, 8], baseline_dirs):
    model_dir = os.path.join(exp_dir, exp)
    config_path = visualization_lib.get_config_path(model_dir)
    config = training_config.TrainingConfig(os.path.join(model_dir, config_path))
    config.read_config()
    title = remap_title(config.architecture + ' ' + config.input_domain)

    img, k, patch = load_imgs(os.path.join(model_dir, 'brain_undersample_test_output.h5'))
    plot_imgs(axes, col, img, k, patch, gt_patch, ind, title=title)

unet_h5_path = '/data/vision/polina/scratch/nmsingh/dev/fouriernetworks/training/melba/baselines_matched/unet_brain_undersample_test_output.h5'
img, k, patch = load_imgs(unet_h5_path)
plot_imgs(axes, 4, img, k, patch, gt_patch, ind, title='U-Net')

cascade_h5_path = '/data/vision/polina/scratch/nmsingh/dev/fouriernetworks/training/melba/baselines_matched/cascade_brain_undersample_test_output.h5'
img, k, patch = load_imgs(cascade_h5_path)
plot_imgs(axes, 5, img, k, patch, gt_patch, ind, title='CascadeNet')
cascade_k = k

kiki_h5_path = '/data/vision/polina/scratch/nmsingh/dev/fouriernetworks/training/melba/baselines_matched/kiki_brain_undersample_test_output.h5'
img, k, patch = load_imgs(kiki_h5_path)
plot_imgs(axes, 6, img, k, patch, gt_patch, ind, title='KIKI')

# Save only after every column is drawn, so both files contain the complete grid.
fig.savefig('methods_comparison.png', bbox_inches='tight')
fig.savefig('undersample_sota.pdf', bbox_inches='tight')

In [4]:
# Row labels for the metric matrices; must match the order of `result_h5_paths`.
network_list = ['Freq Conv', 'Image Conv', 'Alternating', 'Interleaved', 'U-Net', 'CascadeNet', 'KIKI']

# Results files: the four ablation models (in `baseline_dirs` order) followed by
# the three matched baselines. Only the metrics are needed here, so the training
# configs are not re-parsed.
result_h5_paths = [os.path.join(exp_dir, exp, 'brain_undersample_test_output.h5')
                   for exp in baseline_dirs]
result_h5_paths += [unet_h5_path, cascade_h5_path, kiki_h5_path]
assert len(result_h5_paths) == len(network_list), 'expected one results file per network label'

# One row per network; each row is the 'ssim' dataset of that network's results file
# (presumably one value per test example -- confirm).
ssims = np.vstack([load_ssims(path) for path in result_h5_paths])

In [5]:
# Same file order as the SSIM matrix: the four ablation models (in `baseline_dirs`
# order) followed by the three matched baselines. Configs are not needed here.
result_h5_paths = [os.path.join(exp_dir, exp, 'brain_undersample_test_output.h5')
                   for exp in baseline_dirs]
result_h5_paths += [unet_h5_path, cascade_h5_path, kiki_h5_path]
assert len(result_h5_paths) == len(network_list), 'expected one results file per network label'

# One row per network; each row is the 'psnr' dataset of that network's results file.
psnrs = np.vstack([load_psnrs(path) for path in result_h5_paths])

In [6]:
# Label each metric matrix by network (one row per network).
ssims_df, psnrs_df = (pd.DataFrame(scores, index=network_list) for scores in (ssims, psnrs))

In [7]:
def dec_2(data):
    """Render a number as a string with exactly two decimal places."""
    formatted = format(data, '.2f')
    return formatted

def format_mean_std(scores_df, n_bold=2):
    r"""Format one 'mean $\pm$ std' LaTeX cell per row, bolding the best rows.

    Args:
        scores_df: DataFrame with one row per network; statistics are taken
            across each row's columns.
        n_bold: Rows whose mean is >= the n_bold-th largest row mean are wrapped
            in \textbf{...} (ties at the cutoff are bolded too).

    Returns:
        Series of LaTeX strings indexed like `scores_df`.
    """
    means = scores_df.mean(axis=1)
    stds = scores_df.std(axis=1)
    # The cutoff does not depend on the row, so compute it once.
    cutoff = np.sort(means.to_numpy())[-n_bold]
    cells = []
    for mean, std in zip(means, stds):
        # Raw strings: in a plain literal '\t' in '\textbf' would become a TAB.
        cell = dec_2(mean) + r' $\pm$ ' + dec_2(std)
        cells.append(r'\textbf{' + cell + '}' if mean >= cutoff else cell)
    return pd.Series(cells, index=scores_df.index)


# Build the table column-wise instead of chained `df[col][i] = ...` assignment,
# which does not write through under pandas Copy-on-Write.
stats_df = pd.DataFrame({'SSIM': format_mean_std(ssims_df), 'PSNR': format_mean_std(psnrs_df)},
                        index=network_list)

# Column order used in the paper table.
cols = ['Freq Conv', 'Image Conv', 'U-Net', 'CascadeNet', 'KIKI', 'Alternating', 'Interleaved']
stats_df = stats_df.T[cols]
    

In [8]:
# Center every column (index column + one per network) rather than string-replacing
# pandas' default left-aligned format, which silently fails if the column count changes.
print(stats_df.to_latex(escape=False, column_format='c' * (stats_df.shape[1] + 1)))

\begin{tabular}{cccccccc}
\toprule
{} &         Freq Conv &        Image Conv &             U-Net &        CascadeNet &              KIKI &                Alternating &                Interleaved \\
\midrule
SSIM &   0.83 $\pm$ 0.06 &   0.90 $\pm$ 0.03 &   0.85 $\pm$ 0.06 &   0.90 $\pm$ 0.04 &   0.93 $\pm$ 0.03 &   \textbf{0.95 $\pm$ 0.02} &   \textbf{0.95 $\pm$ 0.02} \\
PSNR &  31.08 $\pm$ 3.58 &  32.89 $\pm$ 3.05 &  30.56 $\pm$ 3.60 &  34.62 $\pm$ 4.64 &  37.54 $\pm$ 3.67 &  \textbf{38.98 $\pm$ 3.65} &  \textbf{37.93 $\pm$ 3.74} \\
\bottomrule
\end{tabular}



In [9]:
colors = ['#69f97a', '#95a4ee', '#fcb335', '#d499ff', '#5900b3']
hue_order = ['U-Net', 'CascadeNet', 'KIKI', 'Alternating', 'Interleaved']
this_cmap = sns.color_palette(colors, len(colors))
# Sets the color cycle for axes created afterwards; each scatter call below takes the
# next color, presumably so colors line up with `hue_order`.
sns.set_palette(this_cmap)

fig, ax = plt.subplots(1, 1, figsize=(6, 6))
# Rows of `ssims` (ordered as `network_list`) to draw, in the same order as `hue_order`.
net_inds = [4, 5, 6, 2, 3]
# Order subjects by the Interleaved model's SSIM (row 3) so the curves are comparable.
sort_inds = np.argsort(ssims[3, :])
subjects = np.arange(ssims.shape[1])
for row, label in zip(net_inds, hue_order):
    # x/y must be keywords in seaborn >= 0.12.
    sns.scatterplot(x=subjects, y=ssims[row, sort_inds], label=label, ax=ax, edgecolor=None, s=10)
ax.legend(loc='lower right', framealpha=0.0, handletextpad=0., markerscale=3,
          bbox_to_anchor=(0.5, 0, 0.5, 0.5))

ax.set_xlabel('Subject')
ax.set_ylabel('SSIM')
ax.set_ylim(0.6, 1)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
fig.savefig('undersample_scatter.pdf', bbox_inches='tight')


## Ablations

In [10]:
def plot_ablations(exp_dir, ref_h5_path, h5_results_fn, ind, x, y, n):
    """Draw a 4x6 grid: ground truth, input, and the four ablation models of one experiment.

    Each model column is titled via `remap_title` from that model's training config.
    Draws on a new current figure (callers save it with `plt.savefig`).

    Args:
        exp_dir: Directory containing one subdirectory per ablation model.
        ref_h5_path: Reference HDF5 file holding the ground-truth ('outputs') and
            input ('inputs') data.
        h5_results_fn: File name of the results HDF5 inside each model directory.
        ind: Index of the example to display.
        x, y: Start offsets of the zoomed patch along axes 1 and 2.
        n: Side length of the zoomed patch.
    """
    # Spacing comes from gridspec_kw; figure-level subplots_adjust would be overridden by it.
    fig, axes = plt.subplots(4, 6, figsize=(12, 8), gridspec_kw={'wspace': 0.05, 'hspace': 0.05})

    # NOTE(review): fixed reorder of the sorted model directories; presumably yields
    # Frequency, Image, Alternating, Interleaved -- confirm against the configs.
    sorted_dirs = sorted(os.listdir(exp_dir))
    baseline_dirs = [sorted_dirs[i] for i in [1, 2, 0, 3]]

    gt_img, gt_k, gt_patch = load_ref_imgs(ref_h5_path, inputs=False, x=x, y=y, n=n)
    plot_imgs(axes, 0, gt_img, gt_k, gt_patch, gt_patch, ind, title='Ground Truth')

    img, k, patch = load_ref_imgs(ref_h5_path, inputs=True, x=x, y=y, n=n)
    plot_imgs(axes, 1, img, k, patch, gt_patch, ind, title='Input')

    # Model columns start right after ground truth (0) and input (1).
    for col, exp in enumerate(baseline_dirs, start=2):
        model_dir = os.path.join(exp_dir, exp)
        config_path = visualization_lib.get_config_path(model_dir)
        config = training_config.TrainingConfig(os.path.join(model_dir, config_path))
        config.read_config()
        title = remap_title(config.architecture + ' ' + config.input_domain)

        img, k, patch = load_imgs(os.path.join(model_dir, h5_results_fn), x=x, y=y, n=n)
        plot_imgs(axes, col, img, k, patch, gt_patch, ind, title=title)

In [11]:
# Color cycle for the ablation scatter plots (presumably Frequency, Image,
# Alternating, Interleaved, in that order); applies to axes created afterwards.
colors = ['#ff9999', '#99d6ff', '#d499ff', '#5900b3']
this_cmap = sns.color_palette(colors, n_colors=len(colors))
sns.set_palette(this_cmap)

In [12]:
def plot_ablations_scatter(exp_dir, h5_results_fn,ax=None,ylabel=False,show_legend=False,exp_title=None):
    """Scatter per-subject SSIM for the four ablation models of one experiment.

    Subjects are ordered by the fourth model's SSIM (after the directory reorder
    below), so each model's curve can be compared against it.

    Args:
        exp_dir: Directory containing one subdirectory per ablation model.
        h5_results_fn: Results HDF5 file name inside each model directory.
        ax: Axes to draw on. If None, a new 20x12 pyplot figure is created and
            labelled; the remaining options are then ignored.
        ylabel: If True label the y axis 'SSIM', otherwise clear it.
        show_legend: If True restyle the legend; otherwise remove it.
        exp_title: Title for `ax`.
    """
    # NOTE(review): fixed reorder of the sorted model directories; presumably yields
    # Frequency, Image, Alternating, Interleaved -- confirm against the configs.
    sorted_dirs = sorted(os.listdir(exp_dir))
    baseline_dirs = [sorted_dirs[i] for i in [1, 2, 0, 3]]
    ssims = [load_ssims(os.path.join(exp_dir, exp, h5_results_fn)) for exp in baseline_dirs]

    hue_order = ['Frequency', 'Image', 'Alternating', 'Interleaved']
    legend = ['Freq Conv', 'Image Conv', 'Alternating', 'Interleaved']
    sort_inds = np.argsort(ssims[3])
    subjects = np.arange(len(sort_inds))

    if ax is None:
        plt.figure(figsize=(20, 12))
        for scores in ssims:
            plt.scatter(subjects, scores[sort_inds])
        plt.legend(legend)
        plt.xlabel('Subject')
        plt.ylabel('SSIM')
        return

    # Colors come from the active color cycle; x/y must be keywords in seaborn >= 0.12.
    for scores, label in zip(ssims, hue_order):
        sns.scatterplot(x=subjects, y=scores[sort_inds], label=label, ax=ax, edgecolor=None, s=10)

    ax.set_xlabel('Subject')
    ax.set_ylabel('SSIM' if ylabel else '')

    if show_legend:
        handles, labels = ax.get_legend_handles_labels()
        # Replaces any legend seaborn created for the labelled scatters.
        ax.legend(handles, labels, framealpha=0.0, handletextpad=0., markerscale=3,
                  bbox_to_anchor=(0.6, 0, 0.5, 0.5))
    elif ax.get_legend() is not None:
        ax.get_legend().remove()

    ax.set_ylim(bottom=0.5)
    ax.set_title(exp_title)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

        

In [None]:
undersample_exp_dir = '/data/vision/polina/scratch/nmsingh/dev/fouriernetworks/training/melba/undersample_ablations'
undersample_ref_h5_path = '/data/vision/polina/scratch/nmsingh/dev/fouriernetworks/testsets/brain_undersample_test.h5'
undersample_h5_results_fn = 'brain_undersample_test_output.h5'

# Example index, then zoom-patch start offsets (x, y) and side length n.
ind, x, y, n = 78, 110, 110, 84

plot_ablations(exp_dir=undersample_exp_dir, ref_h5_path=undersample_ref_h5_path,
               h5_results_fn=undersample_h5_results_fn, ind=ind, x=x, y=y, n=n)
plt.savefig('undersample_imgs.pdf', bbox_inches='tight')

In [None]:
motion_exp_dir = '/data/vision/polina/scratch/nmsingh/dev/fouriernetworks/training/melba/motion_ablations'
motion_ref_h5_path = '/data/vision/polina/scratch/nmsingh/dev/fouriernetworks/testsets/brain_motion_test.h5'
motion_h5_results_fn = 'brain_motion_test_output.h5'

# Example index, then zoom-patch start offsets (x, y) and side length n.
ind, x, y, n = 1, 110, 110, 84

plot_ablations(exp_dir=motion_exp_dir, ref_h5_path=motion_ref_h5_path,
               h5_results_fn=motion_h5_results_fn, ind=ind, x=x, y=y, n=n)
plt.savefig('motion_imgs.pdf', bbox_inches='tight')

In [None]:
noise_exp_dir = '/data/vision/polina/scratch/nmsingh/dev/fouriernetworks/training/melba/noise_ablations'
noise_ref_h5_path = '/data/vision/polina/scratch/nmsingh/dev/fouriernetworks/testsets/brain_noise_test.h5'
noise_h5_results_fn = 'brain_noise_test_output.h5'

# Example index, then zoom-patch start offsets (x, y) and side length n.
ind, x, y, n = 3, 110, 110, 84

plot_ablations(exp_dir=noise_exp_dir, ref_h5_path=noise_ref_h5_path,
               h5_results_fn=noise_h5_results_fn, ind=ind, x=x, y=y, n=n)
plt.savefig('noise_imgs.pdf', bbox_inches='tight')

In [None]:
us_motion_exp_dir = '/data/vision/polina/scratch/nmsingh/dev/fouriernetworks/training/melba/undersample_motion_ablations'
us_motion_ref_h5_path = '/data/vision/polina/scratch/nmsingh/dev/fouriernetworks/testsets/brain_undersample_motion_test.h5'
us_motion_h5_results_fn = 'brain_undersample_motion_test_output.h5'

# Example index, then zoom-patch start offsets (x, y) and side length n.
ind, x, y, n = 3, 110, 110, 84

plot_ablations(exp_dir=us_motion_exp_dir, ref_h5_path=us_motion_ref_h5_path,
               h5_results_fn=us_motion_h5_results_fn, ind=ind, x=x, y=y, n=n)
plt.savefig('undersample_motion_imgs.pdf', bbox_inches='tight')

In [None]:
fig, axes = plt.subplots(1, 4, figsize=(16, 4), gridspec_kw={'wspace': 0.2, 'hspace': 0.2})
# One panel per experiment; only the first panel gets a y label and only the last a legend.
scatter_panels = [
    (motion_exp_dir, motion_h5_results_fn, {'ylabel': True, 'exp_title': 'Motion'}),
    (noise_exp_dir, noise_h5_results_fn, {'exp_title': 'Noise'}),
    (undersample_exp_dir, undersample_h5_results_fn, {'exp_title': 'Undersampling'}),
    (us_motion_exp_dir, us_motion_h5_results_fn, {'show_legend': True, 'exp_title': 'Undersampled Motion'}),
]
for panel_ax, (panel_dir, panel_fn, panel_kwargs) in zip(axes, scatter_panels):
    plot_ablations_scatter(panel_dir, panel_fn, ax=panel_ax, **panel_kwargs)
plt.savefig('scatterplots.pdf', bbox_inches='tight')